{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4fa53ae1",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# NLU基础"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86f86ce3-7d51-404c-b548-c884ecb54bfc",
   "metadata": {},
   "source": [
    "NLU是Natural Language Understanding的简称，即自然语言理解。一直以来都与NLG（Generation）任务并称为NLP两大主流任务。一般意义上的NLU常指与理解给定句子意思相关的意图识别、实体抽取、指代关系等任务，在智能对话中应用比较广泛。;NLU是Natural Language Understanding的简称，即自然语言理解。一直以来都与NLG（Generation）任务并称为NLP两大主流任务。一般意义上的NLU常指与理解给定句子意思相关的意图识别、实体抽取、指代关系等任务，在智能对话中应用比较广泛。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99a9393c-6e62-4db9-a6e1-0da77e3d6671",
   "metadata": {},
   "source": [
    "具体点来说，当用户输入一句话时，机器人一般会针对该句话（也可以把历史记录给附加上）进行全方面分析，包括：\n",
    "\n",
    "- 情感倾向分析：简单来说，一般会包括正向、中性、负向三种类型，也可以设计更多的类别。或更复杂的细粒度情感分析，比如针对其中某个实体或属性的情感，而不是整个句子的。\n",
    "- 意图识别：一般都是分类模型，大部分时候都是多分类，但是也有可能是层次分类，或多标签模型。\n",
    "    - 多分类：给定输入文本，输出为一个Label，但Label的总数有多个。比如类型包括询问地址、询问时间、询问价格、闲聊等等。\n",
    "    - 层次分类：给定输入文本，输出为层次的Label，也就是从根节点到最终细粒度类别的路径。比如询问地址/询问家庭地址、询问地址/询问公司地址等等。\n",
    "    - 多标签分类：给定输入文本，输出不定数量的Label，也就是说每个文本可能有多个Label，Label之间是平级关系。\n",
    "- 实体和关系抽取：\n",
    "    - 实体抽取：提取出给定文本中的实体。实体一般指具有特定意义的实词，如人名、地名、作品、品牌等等；很多时候也是业务直接相关的词。\n",
    "    - 关系抽取：实体之间往往有一定的关系，比如「刘亦菲」出演「天龙八部」，其中「刘亦菲」就是人名、「天龙八部」是作品名，其中的关系就是「出演」，一般会和实体作为三元组来表示。\n",
    "\n",
    "一般经过以上这些分析后，机器人就可以对用户的输入有一个比较清晰的理解，便于接下来据此做出响应。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e973ff2-232b-4f4f-b9ae-84d0395cc3e6",
   "metadata": {},
   "source": [
    "另外值得一提的是，上面的过程并不一定只用在对话中，只要涉及到用户输入Query需要给出响应的场景，都需要这个NLU的过程，一般也叫Query解析。\n",
    "\n",
    "上面提到的几个分析，如果从算法的角度看，其实就两种：\n",
    "\n",
    "- 句子级别的分类：如情感分析、意图识别、关系抽取等。也就是给一个句子，给出一个或多个Label。\n",
    "- Token级别的分类：如实体抽取、阅读理解（就是给一段文本和一个问题，然后在文本中找到问题的答案）。也就是给一个句子，给出对应实体的位置。"
   ]
  },
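  {
   "cell_type": "markdown",
   "id": "nlu-sketch-bio-tags",
   "metadata": {},
   "source": [
    "The difference between the two kinds of output can be sketched in a few lines. This is only an illustration: the `spans_to_bio` helper, the BIO tagging scheme, and the label names `PER` and `WORK` are assumptions chosen for the example, not something the text above prescribes.\n",
    "\n",
    "```python\n",
    "# Hypothetical helper: turn token-index entity spans into BIO tags.\n",
    "def spans_to_bio(num_tokens, spans):\n",
    "    # spans: list of (start, end, label), with end exclusive\n",
    "    tags = ['O'] * num_tokens\n",
    "    for start, end, label in spans:\n",
    "        tags[start] = f'B-{label}'          # first token of the entity\n",
    "        for i in range(start + 1, end):\n",
    "            tags[i] = f'I-{label}'          # remaining tokens of the entity\n",
    "    return tags\n",
    "\n",
    "tokens = ['Liu', 'Yifei', 'starred', 'in', 'Demi-Gods', 'and', 'Semi-Devils']\n",
    "\n",
    "# Sentence-level classification: one label for the whole sentence.\n",
    "sentence_label = 'neutral'\n",
    "\n",
    "# Token-level classification: one tag per token.\n",
    "token_tags = spans_to_bio(len(tokens), [(0, 2, 'PER'), (4, 7, 'WORK')])\n",
    "for token, tag in zip(tokens, token_tags):\n",
    "    print(token, tag)\n",
    "```\n",
    "\n",
    "A sentence-level model produces a single value such as `sentence_label`, whereas a token-level model produces a sequence as long as the input, from which entity positions can be read off."
   ]
  },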
  {
   "cell_type": "markdown",
   "id": "be9a4159",
   "metadata": {},
   "source": [
    "# 相关API"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25e5b52d",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "## LMAS GPT API"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78499e24",
   "metadata": {},
   "source": [
    "&emsp;&emsp;这里我们介绍openai的GPT接口，利用GPT大模型的In-Context能力进行Zero-Shot或Few-Shot的推理。这里有四个概念需要先稍微解释一下：\n",
    "\n",
    "- GPT：全程是Generative Pretrained Transformer，生成式预训练Transformer。大家只要知道它是一个大模型的名字即可。\n",
    "- In-Context：简单来说就是一种上下文能力，也就是模型只要根据输入的文本就可以自动给出对应的结果，这种能力是大模型在学习了非常多的文本后获得的。可以看作是一种内在的理解能力。\n",
    "- Zero-Shot：直接给模型文本，让它给出你要的标签或输出。\n",
    "- Few-Shot：给模型一些类似的Case（输入+输出），再拼上一个新的没有输出的输入，让模型给出输出。\n",
    "\n",
    "&emsp;&emsp;如果对In-Context更多细节感兴趣的，可以阅读【相关文献1】。"
   ]
  },
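  {
   "cell_type": "markdown",
   "id": "nlu-sketch-few-shot",
   "metadata": {},
   "source": [
    "The zero-shot and few-shot formats differ only in how the input string is assembled. A minimal sketch, assuming a hypothetical helper `build_few_shot_prompt` and the `-->` separator (both are illustrative choices, not part of the OpenAI SDK):\n",
    "\n",
    "```python\n",
    "# Hypothetical helper: build a few-shot prompt from (input, output) pairs.\n",
    "def build_few_shot_prompt(examples, query, sep='-->'):\n",
    "    # One line per solved example, then the unsolved query with an empty output slot.\n",
    "    lines = [f'{text}{sep}{label}' for text, label in examples]\n",
    "    lines.append(f'{query}{sep}')\n",
    "    return '\\n'.join(lines) + '\\n'\n",
    "\n",
    "examples = [('今天真开心。', '正向'), ('心情不太好。', '负向')]\n",
    "few_shot = build_few_shot_prompt(examples, '我们是快乐的年轻人。')\n",
    "zero_shot = build_few_shot_prompt([], '我们是快乐的年轻人。')  # no solved examples\n",
    "print(few_shot)\n",
    "```\n",
    "\n",
    "With an empty example list the same helper yields a zero-shot input; in practice a zero-shot prompt would usually also state the task in words, since there are no examples to infer it from."
   ]
  },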
  {
   "cell_type": "markdown",
   "id": "e5aff41d",
   "metadata": {},
   "source": [
    "&emsp;&emsp;接下来，我们就可以用同一个接口，只要通过构造不同的输入就可以完成不同的任务。换句话说，通过使用GPT大模型的In-Context能力，我们只需要输入的时候告诉模型我们的任务就行。\n",
    "\n",
    "&emsp;&emsp;我们看看具体的用法："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1a71d0d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import openai\n",
    "# 导入自己的API key\n",
    "openai.api_key = os.environ.get(\"OPENAI_API_KEY\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b2c59486",
   "metadata": {},
   "outputs": [],
   "source": [
    "def complete(prompt):\n",
    "    response = openai.Completion.create(\n",
    "      model=\"text-davinci-003\",\n",
    "      prompt=prompt,\n",
    "      temperature=0,\n",
    "      max_tokens=64,\n",
    "      top_p=1.0,\n",
    "      frequency_penalty=0.0,\n",
    "      presence_penalty=0.0\n",
    "    )\n",
    "    ans = response.choices[0].text\n",
    "    return ans"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d752351c",
   "metadata": {},
   "source": [
    "&emsp;&emsp;这个是`Completion`接口，可以理解为续写，这个续写可不止能帮助我们完成一段话或一篇文章，而且可以用来做各种各样的任务，比如咱们这章要讲的分类和实体提取任务。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18b86dd7",
   "metadata": {},
   "source": [
    "&emsp;&emsp;相比上一章的`Embedding`接口，它的接口参数要复杂多了，重要的参数包括：\n",
    "\n",
    "- model：指定的模型，`text-davinci-003`就是其中一个模型，大家可以根据自己的需要，参考官方[链接](https://platform.openai.com/docs/models/gpt-3)进行选择，一般需要综合价格和效果进行权衡。\n",
    "\n",
    "- prompt：提示，默认为`<|endoftext|>`，它是模型在训练期间看到的文档分隔符，因此如果未指定Prompt，模型将像从新文档的开始一样。简单来说，就是给模型的提示语，咱们下面有例子。\n",
    "\n",
    "- max_tokens：生成的最大Token数，默认为16。注意这里的Token数不一定是字数（但对中文来说几乎一致）。Prompt+生成的文本，所有的Token长度不能超过模型的上下文长度（一般是2048，新的是4096，具体可以参考上面的链接）。\n",
    "\n",
    "- temperature：温度，默认为1。采样温度，介于0和2之间。较高的值（如0.8）将使输出更加随机，而较低的值（如0.2）将使其更加集中和确定。通常建议调整这个参数或下面的top_p，但不能同时更改两者。\n",
    "\n",
    "- top_p：采样topN分布，默认为1。0.1意味着Next Token只选择前10%概率的。\n",
    "\n",
    "- stop：停止的Token或序列，默认为null，最多4个，如果遇到该Token或序列就停止继续生成。注意生成的结果中不包含stop。\n",
    "\n",
    "- presence_penalty：存在惩罚，默认为0，介于-2.0和2.0之间的数字。正值会根据新Token到目前为止是否出现在文本中来惩罚它们，从而增加模型讨论新主题的可能性。\n",
    "\n",
    "- frequency_penalty：频率惩罚，默认为0，介于-2.0和2.0之间的数字。正值会根据新Token到目前为止在文本中的现有频率来惩罚新Token，降低模型重复生成同一行的可能性。\n",
    "\n",
    "&emsp;&emsp;更多可以参考：[API Reference - OpenAI API](https://platform.openai.com/docs/api-reference/completions/create)。在大部分情况下，我们只需考虑上面这几个参数即可，甚至大部分时候只需要前两个参数，其他的用默认也行。不过熟悉上面的参数将帮助你更好地使用API。"
   ]
  },
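  {
   "cell_type": "markdown",
   "id": "nlu-sketch-sampling",
   "metadata": {},
   "source": [
    "To make `temperature` and `top_p` more concrete, here is a simplified sketch of the standard techniques behind them (temperature-scaled softmax and nucleus filtering). It is not OpenAI's implementation; the function names and the toy logits are assumptions for illustration, and the sketch requires `temperature > 0`.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def apply_temperature(logits, temperature):\n",
    "    # Divide the scores by the temperature, then apply a numerically stable softmax.\n",
    "    scaled = [x / temperature for x in logits]\n",
    "    m = max(scaled)\n",
    "    exps = [math.exp(x - m) for x in scaled]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "def top_p_filter(probs, top_p):\n",
    "    # Keep the smallest set of most likely tokens whose cumulative\n",
    "    # probability reaches top_p, then renormalize over that set.\n",
    "    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)\n",
    "    kept, cumulative = [], 0.0\n",
    "    for i in order:\n",
    "        kept.append(i)\n",
    "        cumulative += probs[i]\n",
    "        if cumulative >= top_p:\n",
    "            break\n",
    "    total = sum(probs[i] for i in kept)\n",
    "    return {i: probs[i] / total for i in kept}\n",
    "\n",
    "logits = [2.0, 1.0, 0.1]                  # toy scores for three candidate tokens\n",
    "sharp = apply_temperature(logits, 0.2)    # low temperature: mass concentrates on token 0\n",
    "flat = apply_temperature(logits, 2.0)     # high temperature: distribution flattens\n",
    "nucleus = top_p_filter(apply_temperature(logits, 1.0), 0.1)\n",
    "print(sharp, flat, nucleus)\n",
    "```\n",
    "\n",
    "In this toy case the most likely token alone already covers more than 10% of the probability mass, so `top_p=0.1` leaves only that token as a candidate."
   ]
  },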
  {
   "cell_type": "markdown",
   "id": "8f89c5a4-68ec-4540-8a6e-e84ca9719feb",
   "metadata": {},
   "source": [
    "**句子级别的分类案例**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5c3bbb5a-0e8d-4e52-9190-9ae833bae322",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Technology \n",
      "\n",
      "Facebook\n",
      "Category:\n",
      "Social Media \n",
      "\n",
      "Fedex\n",
      "Category:\n",
      "Logistics and Delivery\n"
     ]
    }
   ],
   "source": [
    "# Zero-Shot的调用方式\n",
    "prompt=\"\"\"The following is a list of companies and the categories they fall into:\n",
    "\n",
    "Apple, Facebook, Fedex\n",
    "\n",
    "Apple\n",
    "Category:\n",
    "\"\"\"\n",
    "ans = complete(prompt)\n",
    "print(ans)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8428db99-e7bb-4f6f-ae6c-210099801e25",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "正向\n"
     ]
    }
   ],
   "source": [
    "# Few-Shot的调用方式\n",
    "prompt = \"\"\"今天真开心。-->正向\n",
    "心情不太好。-->负向\n",
    "我们是快乐的年轻人。-->\n",
    "\"\"\"\n",
    "ans = complete(prompt)\n",
    "print(ans)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ebe228f-255f-4a7e-a5e7-b3b556556875",
   "metadata": {},
   "source": [
    "**Token级别的分类**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "15af41da-df96-4256-b751-7e07e355587f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Companies: AT&T, Bell South, Bell Atlantic, NYNEX, American Information Technologies, Southwestern Bell, US West, Pacific Telesis\n",
      "People & titles: William Baxter (Assistant Attorney General), Charles L. Brown (AT&T chairman), Harold H. Greene (Judge)\n"
     ]
    }
   ],
   "source": [
    "# Zero-Shot 来自openai官方文档\n",
    "prompt = \"\"\"\n",
    "From the text below, extract the following entities in the following format:\n",
    "Companies: <comma-separated list of companies mentioned>\n",
    "People & titles: <comma-separated list of people mentioned (with their titles or roles appended in parentheses)>\n",
    "\n",
    "Text:\n",
    "In March 1981, United States v. AT&T came to trial under Assistant Attorney General William Baxter. AT&T chairman Charles L. Brown thought the company would be gutted. He realized that AT&T would lose and, in December 1981, resumed negotiations with the Justice Department. Reaching an agreement less than a month later, Brown agreed to divestiture—the best and only realistic alternative. AT&T's decision allowed it to retain its research and manufacturing arms. The decree, titled the Modification of Final Judgment, was an adjustment of the Consent Decree of 14 January 1956. Judge Harold H. Greene was given the authority over the modified decree....\n",
    "\n",
    "In 1982, the U.S. government announced that AT&T would cease to exist as a monopolistic entity. On 1 January 1984, it was split into seven smaller regional companies, Bell South, Bell Atlantic, NYNEX, American Information Technologies, Southwestern Bell, US West, and Pacific Telesis, to handle regional phone services in the U.S. AT&T retains control of its long distance services, but was no longer protected from competition.\n",
    "\"\"\"\n",
    "ans = complete(prompt)\n",
    "print(ans)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "da98a66c-f5d0-4c8d-a9ab-b935b97d7004",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "和弦：增三和弦，减三和弦\n"
     ]
    }
   ],
   "source": [
    "# Few-Shot\n",
    "prompt = \"\"\"\n",
    "根据下面的格式抽取给定Text中的实体:\n",
    "和弦: <实体用逗号分割>\n",
    "\n",
    "Text:\n",
    "三和弦是由3个按照三度音程关系排列起来的一组音。大三和弦是大三度+小三度的纯五度音，小三和弦是小三度+大三度的纯五度音。\n",
    "和弦：大三和弦，小三和弦\n",
    "\n",
    "Text:\n",
    "增三和弦是大三度+大三度的增五度音，减三和弦是小三度+小三度的减五度音。\n",
    "\"\"\"\n",
    "ans = complete(prompt)\n",
    "print(ans)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1e623de",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "## ChatGPT的风格"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "126691a4",
   "metadata": {},
   "source": [
    "&emsp;&emsp;这个是`ChatCompletions`接口，可以理解为对话（也就是ChatGPT），几乎可以做任意的NLP任务。它的参数和`Completion`类似，我们依然介绍主要参数：\n",
    "\n",
    "- model：指定的模型，`gpt-3.5-turbo`就是ChatGPT，大家还是可以根据实际情况参考官方给出的[列表](https://platform.openai.com/docs/models/gpt-3)选择合适的模型。\n",
    "\n",
    "- messages：会话消息，支持多轮，多轮就是多条。每一条消息为一个字典，包含「role」和「content」两个字段。如：`[{\"role\": \"user\", \"content\": \"Hello!\"}]`\n",
    "\n",
    "- temperature：和`Completion`接口含义一样。\n",
    "\n",
    "- top_p：和`Completion`接口含义一样。\n",
    "\n",
    "- stop：和`Completion`接口含义一样。\n",
    "\n",
    "- max_tokens：默认无上限，其他和`Completion`接口含义一样，也受限于模型的最大上下文长度。\n",
    "\n",
    "- presence_penalty：和`Completion`接口含义一样。\n",
    "\n",
    "- frequency_penalty：和`Completion`接口含义一样。\n",
    "\n",
    "&emsp;&emsp;更多可以参考：[API Reference - OpenAI API](https://platform.openai.com/docs/api-reference/completions/create)，值得再次一提的是，接口支持多轮，而且多轮非常简单，只需要把历史会话加进去就可以了。"
   ]
  },
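  {
   "cell_type": "markdown",
   "id": "nlu-sketch-messages",
   "metadata": {},
   "source": [
    "As a sketch of what \"including the history\" looks like, the snippet below assembles a multi-turn `messages` list without calling the API. The helper `build_messages` and the shape of `history` are assumptions made for this example; the `system`, `user`, and `assistant` roles are the ones the chat interface accepts.\n",
    "\n",
    "```python\n",
    "# Hypothetical helper: assemble a multi-turn `messages` list.\n",
    "def build_messages(history, user_input, system=None):\n",
    "    # history: list of (user_text, assistant_text) pairs from earlier turns\n",
    "    messages = []\n",
    "    if system is not None:\n",
    "        messages.append({'role': 'system', 'content': system})\n",
    "    for user_text, assistant_text in history:\n",
    "        messages.append({'role': 'user', 'content': user_text})\n",
    "        messages.append({'role': 'assistant', 'content': assistant_text})\n",
    "    # The new, still unanswered user turn always comes last.\n",
    "    messages.append({'role': 'user', 'content': user_input})\n",
    "    return messages\n",
    "\n",
    "history = [('Hello!', 'Hi, how can I help you?')]\n",
    "messages = build_messages(history, 'Classify the sentiment of: I love this song.')\n",
    "for m in messages:\n",
    "    print(m['role'], ':', m['content'])\n",
    "```\n",
    "\n",
    "The resulting list can be passed as the `messages` argument; after each reply, appending the new (user, assistant) pair to `history` keeps the conversation state on the caller's side."
   ]
  },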
  {
   "cell_type": "markdown",
   "id": "18a14955",
   "metadata": {},
   "source": [
    "&emsp;&emsp;接下来，我们采用ChatGPT方式来做类似的任务。这个输入看起来和上面是比较类似的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a237b8ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ask(content):\n",
    "    response = openai.ChatCompletion.create(\n",
    "        model=\"gpt-3.5-turbo\", \n",
    "        messages=[{\"role\": \"user\", \"content\": content}]\n",
    "    )\n",
    "\n",
    "    ans = response.get(\"choices\")[0].get(\"message\").get(\"content\")\n",
    "    return ans"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de730e4e",
   "metadata": {},
   "source": [
    "&emsp;&emsp;我们依次尝试上面的例子："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8afb90eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt=\"\"\"The following is a list of companies and the categories they fall into:\n",
    "\n",
    "Apple, Facebook, Fedex\n",
    "\n",
    "Apple\n",
    "Category:\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5d8dae58",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Technology/Electronics\n",
      "\n",
      "Facebook\n",
      "Category:\n",
      "Social Media/Technology\n",
      "\n",
      "Fedex\n",
      "Category:\n",
      "Shipping/Logistics/Postal delivery\n"
     ]
    }
   ],
   "source": [
    "ans = ask(prompt)\n",
    "print(ans)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2401d7d",
   "metadata": {},
   "source": [
    "&emsp;&emsp;嗯，效果也是类似的，不过在ChatGPT这里我们可以更加精简一些："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "147d8d45",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt=\"\"\"please output the category of the following companies:\n",
    "Apple, Facebook, Fedex\n",
    "\n",
    "The output format should be:\n",
    "<company>\n",
    "Category:\n",
    "<category>\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "72293253",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Apple\n",
      "Category:\n",
      "Technology/Electronics\n",
      "\n",
      "Facebook\n",
      "Category:\n",
      "Technology/Social Media\n",
      "\n",
      "Fedex\n",
      "Category:\n",
      "Logistics/Transportation\n"
     ]
    }
   ],
   "source": [
    "ans = ask(prompt)\n",
    "print(ans)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f899e4d8",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Great! ChatGPT比前面的GPT API更加「聪明」，交互更加自然。我们可以尝试以对话的方式来让它帮忙完成任务。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83d12d87",
   "metadata": {},
   "source": [
    "&emsp;&emsp;实际上，关于这个领域早就有了一个成熟的技术方案：Prompt工程，大家可以进一步阅读【相关文献2】。这里给出一些常见的建议：\n",
    "\n",
    "- 清晰，切忌复杂或歧义，如果有术语，应定义清楚。\n",
    "- 具体，描述语言应尽量具体，不要抽象活模棱两可。\n",
    "- 聚焦，问题避免太泛或开放。\n",
    "- 简洁，避免不必要的描述。\n",
    "- 相关，主要指主题相关，而且是整个对话期间，不要东一瓢西一瓤。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63c3e846",
   "metadata": {},
   "source": [
    "&emsp;&emsp;新手容易忽略的地方：\n",
    "\n",
    "- 没有说明具体的输出目标。\n",
    "- 在一次对话中混合多个主题。\n",
    "- 让语言模型做数学题。\n",
    "- 没有给出想要什么的示例样本。\n",
    "- 反向提示。也就是一些反面例子。\n",
    "- 要求他一次只做一件事。可以将步骤捆绑在一起一次说清，不要拆的太碎。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abc85191",
   "metadata": {},
   "source": [
    "&emsp;&emsp;我们来试一下情感分类的例子："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5aa98dc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"\"\"请给出下面句子的情感倾向，情感倾向包括三种：正向、中性、负向。\n",
    "句子：我们是快乐的年轻人。\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "cdae4a08",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "情感倾向：正向\n"
     ]
    }
   ],
   "source": [
    "ans = ask(prompt)\n",
    "print(ans)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c94d95c7",
   "metadata": {},
   "source": [
    "&emsp;&emsp;再来做一下实体的例子："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "4b00d82b",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"\"\"\n",
    "请抽取给定Text中的实体，实体包括Company和People&Title，对于People，请同时给出它们的Title或role，跟在实体后面，用括号括起。\n",
    "\n",
    "Text:\n",
    "In March 1981, United States v. AT&T came to trial under Assistant Attorney General William Baxter. AT&T chairman Charles L. Brown thought the company would be gutted. He realized that AT&T would lose and, in December 1981, resumed negotiations with the Justice Department. Reaching an agreement less than a month later, Brown agreed to divestiture—the best and only realistic alternative. AT&T's decision allowed it to retain its research and manufacturing arms. The decree, titled the Modification of Final Judgment, was an adjustment of the Consent Decree of 14 January 1956. Judge Harold H. Greene was given the authority over the modified decree....\n",
    "\n",
    "In 1982, the U.S. government announced that AT&T would cease to exist as a monopolistic entity. On 1 January 1984, it was split into seven smaller regional companies, Bell South, Bell Atlantic, NYNEX, American Information Technologies, Southwestern Bell, US West, and Pacific Telesis, to handle regional phone services in the U.S. AT&T retains control of its long distance services, but was no longer protected from competition.\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "e848b6b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "实体：\n",
      "- Company: AT&T, Bell South, Bell Atlantic, NYNEX, American Information Technologies, Southwestern Bell, US West, Pacific Telesis\n",
      "- People & Title: William Baxter (Assistant Attorney General), Charles L. Brown (AT&T chairman), Judge Harold H. Greene\n"
     ]
    }
   ],
   "source": [
    "ans = ask(prompt)\n",
    "print(ans)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4526fb4b",
   "metadata": {},
   "source": [
    "&emsp;&emsp;看起来还行。我们最后试一下上面的另一个例子："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "18272b0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"\"\"\n",
    "根据下面的格式抽取给定Text中的「和弦」实体，应包括「和弦」两个字\n",
    "\n",
    "Desired format:\n",
    "和弦：<用逗号隔开>\n",
    "\n",
    "Text:\n",
    "增三和弦是大三度+大三度的增五度音，减三和弦是小三度+小三度的减五度音。\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "280ff25e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "和弦：增三和弦,减三和弦\n"
     ]
    }
   ],
   "source": [
    "ans = ask(prompt)\n",
    "print(ans)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25c4ec7f",
   "metadata": {},
   "source": [
    "&emsp;&emsp;我们还用了中英文混合，它也完全没问题。\n",
    "\n",
    "&emsp;&emsp;大家不妨多多尝试，也可以参考【相关文献 2-10】中的写法，总的来说，它并没有什么标准答案，最多也是一种习惯或约定。大家可以自由尝试，不用有任何负担。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d2c7c9f",
   "metadata": {
    "tags": []
   },
   "source": [
    "# NLU应用"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c36147c",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "## 文档问答"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96c06e33",
   "metadata": {},
   "source": [
    "&emsp;&emsp;文档问答和上一章的QA有点类似，不过要稍微复杂一点。它会先用QA的方法召回一个相关的文档，然后让模型在这个文档中找出问题的答案。一般的流程还是先召回相关文档，然后做阅读理解任务。阅读理解和实体提取任务有些类似，但它预测的不是具体某个标签，而是答案的Index，即start和end的位置。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59853e21",
   "metadata": {},
   "source": [
    "&emsp;&emsp;还是举个例子。假设我们的问题是：“北京奥运会举办于哪一年？”\n",
    "\n",
    "&emsp;&emsp;召回的文档可能是含有北京奥运会举办的新闻，比如类似下面这样的：\n",
    "\n",
    "> 第29届夏季奥林匹克运动会（Beijing 2008; Games of the XXIX Olympiad），又称2008年北京奥运会，2008年8月8日晚上8时整在中国首都北京开幕。8月24日闭幕。\n",
    "\n",
    "&emsp;&emsp;标注就是「2008年」这个答案的索引。\n",
    "\n",
    "&emsp;&emsp;当然，一个文档里可能有不止一个问题，比如上面的文档，还可以问：“北京奥运会啥时候开幕？”，“北京奥运会什么时候闭幕”，“北京奥运会是第几届奥运会”等问题。"
   ]
  },
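  {
   "cell_type": "markdown",
   "id": "nlu-sketch-answer-span",
   "metadata": {},
   "source": [
    "The start/end annotation can be illustrated with plain string operations. This is a minimal sketch: `find_answer_span` is a hypothetical helper, the context is a shortened version of the passage above, and a real annotation tool would also have to decide which occurrence of the answer is the intended one.\n",
    "\n",
    "```python\n",
    "# Hypothetical helper: locate the first occurrence of the answer in the context.\n",
    "def find_answer_span(context, answer):\n",
    "    start = context.find(answer)\n",
    "    if start == -1:\n",
    "        return None                      # the answer does not appear verbatim\n",
    "    return start, start + len(answer)    # end index is exclusive\n",
    "\n",
    "context = ('The Games of the XXIX Olympiad (Beijing 2008), also known as the '\n",
    "           '2008 Beijing Olympics, opened in Beijing on 8 August 2008.')\n",
    "span = find_answer_span(context, '2008')\n",
    "start, end = span\n",
    "print(span, context[start:end])\n",
    "```\n",
    "\n",
    "A reading comprehension model is trained to predict exactly such a `(start, end)` pair for a given question and context."
   ]
  },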
  {
   "cell_type": "markdown",
   "id": "db4aeb4c",
   "metadata": {},
   "source": [
    "&emsp;&emsp;根据之前的NLP方法，这里实际做起来方案会比较多，也有一定的复杂度；不过总的来说还是分类任务。现在我们有了LLM，问题就变得简单了。依然是两步：\n",
    "\n",
    "1. 召回：与上一章的QA类似，这次召回的是Doc，这一步其实就是相似Embedding选择最相似的。\n",
    "2. 回答：将召回来的文档和问题以Prompt的方式提交给Completion/ChatCompletion接口，直接得到答案。"
   ]
  },
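  {
   "cell_type": "markdown",
   "id": "nlu-sketch-retrieve-answer",
   "metadata": {},
   "source": [
    "The two steps can be sketched without any external service. Everything here is illustrative: the two-dimensional vectors stand in for real embeddings, and `cosine`, `retrieve`, and `build_qa_prompt` are hypothetical helpers, not library functions.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def cosine(a, b):\n",
    "    # Cosine similarity between two vectors of equal length.\n",
    "    dot = sum(x * y for x, y in zip(a, b))\n",
    "    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))\n",
    "    return dot / norm\n",
    "\n",
    "def retrieve(query_vec, doc_vecs):\n",
    "    # Step 1: return the index of the document most similar to the query.\n",
    "    return max(range(len(doc_vecs)), key=lambda i: cosine(query_vec, doc_vecs[i]))\n",
    "\n",
    "def build_qa_prompt(context, question):\n",
    "    # Step 2: pack the retrieved document and the question into one prompt.\n",
    "    return f'Context:\\n{context}\\n\\nQ: {question}\\nA:'\n",
    "\n",
    "docs = ['The 2008 Olympics were held in Beijing.', 'Qdrant is a vector search engine.']\n",
    "doc_vecs = [[0.9, 0.1], [0.1, 0.9]]   # toy embeddings, not real model output\n",
    "query_vec = [0.8, 0.2]                # toy embedding of the question\n",
    "\n",
    "best = retrieve(query_vec, doc_vecs)\n",
    "prompt = build_qa_prompt(docs[best], 'Where were the 2008 Olympics held?')\n",
    "print(prompt)\n",
    "```\n",
    "\n",
    "The resulting `prompt` string is what would be sent to either interface; only the retrieval backend and the embedding model change in a real system."
   ]
  },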
  {
   "cell_type": "markdown",
   "id": "8eb4bd97",
   "metadata": {},
   "source": [
    "&emsp;&emsp;我们分别用两种不同的接口各举一例，首先看看`Completion`接口："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "dba3529a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import openai\n",
    "# 导入自己的API key\n",
    "openai.api_key = os.environ.get(\"OPENAI_API_KEY\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "0d48a421",
   "metadata": {},
   "outputs": [],
   "source": [
    "def complete(prompt):\n",
    "    response = openai.Completion.create(\n",
    "        prompt=prompt,\n",
    "        temperature=0,\n",
    "        max_tokens=300,\n",
    "        top_p=1,\n",
    "        frequency_penalty=0,\n",
    "        presence_penalty=0,\n",
    "        model=\"text-davinci-003\"\n",
    "    )\n",
    "    ans = response[\"choices\"][0][\"text\"].strip(\" \\n\")\n",
    "    return ans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "16b79d7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 来自官方文档\n",
    "prompt = \"\"\"Answer the question as truthfully as possible using the provided text, and if the answer is not contained within the text below, say \"I don't know\"\n",
    "\n",
    "Context:\n",
    "The men's high jump event at the 2020 Summer Olympics took place between 30 July and 1 August 2021 at the Olympic Stadium.\n",
    "33 athletes from 24 nations competed; the total possible number depended on how many nations would use universality places \n",
    "to enter athletes in addition to the 32 qualifying through mark or ranking (no universality places were used in 2021).\n",
    "Italian athlete Gianmarco Tamberi along with Qatari athlete Mutaz Essa Barshim emerged as joint winners of the event following\n",
    "a tie between both of them as they cleared 2.37m. Both Tamberi and Barshim agreed to share the gold medal in a rare instance\n",
    "where the athletes of different nations had agreed to share the same medal in the history of Olympics. \n",
    "Barshim in particular was heard to ask a competition official \"Can we have two golds?\" in response to being offered a \n",
    "'jump off'. Maksim Nedasekau of Belarus took bronze. The medals were the first ever in the men's high jump for Italy and \n",
    "Belarus, the first gold in the men's high jump for Italy and Qatar, and the third consecutive medal in the men's high jump\n",
    "for Qatar (all by Barshim). Barshim became only the second man to earn three medals in high jump, joining Patrik Sjöberg\n",
    "of Sweden (1984 to 1992).\n",
    "\n",
    "Q: Who won the 2020 Summer Olympics men's high jump?\n",
    "A:\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "ae2966ea",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Gianmarco Tamberi and Mutaz Essa Barshim emerged as joint winners of the event.'"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "complete(prompt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bf7fc74",
   "metadata": {},
   "source": [
    "&emsp;&emsp;上面的Context就是我们召回的文档。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00e01352",
   "metadata": {},
   "source": [
    "&emsp;&emsp;再看`ChatCompletion`接口："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "32875d6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"\"\"请根据以下Context回答问题，直接输出答案即可，不用附带任何上下文。\n",
    "\n",
    "Context:\n",
    "诺曼人（诺曼人：Nourmands；法语：Normands；拉丁语：Normanni）是在10世纪和11世纪将名字命名为法国诺曼底的人。他们是北欧人的后裔（丹麦人，挪威人和挪威人）的海盗和海盗，他们在首相罗洛（Rollo）的领导下向西弗朗西亚国王查理三世宣誓效忠。经过几代人的同化，并与法兰克和罗马高卢人本地居民融合，他们的后代将逐渐与以西卡罗来纳州为基础的加洛林人文化融合。诺曼人独特的文化和种族身份最初出现于10世纪上半叶，并在随后的几个世纪中持续发展。\n",
    "\n",
    "问题：\n",
    "诺曼底在哪个国家/地区？\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "5b4e76fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ask(content):\n",
    "    response = openai.ChatCompletion.create(\n",
    "        model=\"gpt-3.5-turbo\", \n",
    "        messages=[{\"role\": \"user\", \"content\": content}]\n",
    "    )\n",
    "\n",
    "    ans = response.get(\"choices\")[0].get(\"message\").get(\"content\")\n",
    "    return ans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "5f49e938",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "法国。\n"
     ]
    }
   ],
   "source": [
    "ans = ask(prompt)\n",
    "print(ans)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5767e7ba",
   "metadata": {},
   "source": [
    "&emsp;&emsp;看起来还行，我们接下来就把整个流程串起来，先用`Completion`接口实现（便宜），不过也很方便替换过去，毕竟输入都不变（都是Prompt）。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "051ef1d2",
   "metadata": {},
   "source": [
    "&emsp;&emsp;首先是加载数据集，取自：[openai-cookbook/olympics-1-collect-data.ipynb at 1f6c2304b401e931928e74e978d9a0b8a40d1cf7 · openai/openai-cookbook](https://github.com/openai/openai-cookbook/blob/1f6c2304b401e931928e74e978d9a0b8a40d1cf7/examples/fine-tuned_qa/olympics-1-collect-data.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "773d036f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3964, 4)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "df = pd.read_csv(\"data/olympics_sections_text.csv\")\n",
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0ae22998",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>title</th>\n",
       "      <th>heading</th>\n",
       "      <th>content</th>\n",
       "      <th>tokens</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2020 Summer Olympics</td>\n",
       "      <td>Summary</td>\n",
       "      <td>The 2020 Summer Olympics (Japanese: 2020年夏季オリン...</td>\n",
       "      <td>726</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2020 Summer Olympics</td>\n",
       "      <td>Host city selection</td>\n",
       "      <td>The International Olympic Committee (IOC) vote...</td>\n",
       "      <td>126</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2020 Summer Olympics</td>\n",
       "      <td>Impact of the COVID-19 pandemic</td>\n",
       "      <td>In January 2020, concerns were raised about th...</td>\n",
       "      <td>374</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2020 Summer Olympics</td>\n",
       "      <td>Qualifying event cancellation and postponement</td>\n",
       "      <td>Concerns about the pandemic began to affect qu...</td>\n",
       "      <td>298</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2020 Summer Olympics</td>\n",
       "      <td>Effect on doping tests</td>\n",
       "      <td>Mandatory doping tests were being severely res...</td>\n",
       "      <td>163</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  title                                         heading  \\\n",
       "0  2020 Summer Olympics                                         Summary   \n",
       "1  2020 Summer Olympics                             Host city selection   \n",
       "2  2020 Summer Olympics                 Impact of the COVID-19 pandemic   \n",
       "3  2020 Summer Olympics  Qualifying event cancellation and postponement   \n",
       "4  2020 Summer Olympics                          Effect on doping tests   \n",
       "\n",
       "                                             content  tokens  \n",
       "0  The 2020 Summer Olympics (Japanese: 2020年夏季オリン...     726  \n",
       "1  The International Olympic Committee (IOC) vote...     126  \n",
       "2  In January 2020, concerns were raised about th...     374  \n",
       "3  Concerns about the pandemic began to affect qu...     298  \n",
       "4  Mandatory doping tests were being severely res...     163  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe4abb41",
   "metadata": {},
   "source": [
    "&emsp;&emsp;我们这次不用Redis，换一个工具：[Qdrant - Vector Search Engine](https://qdrant.tech/)，Qdrant相比Redis的单线程更容易扩展。但我们切记，要根据实际情况选择工具，很多时候过度优化是原罪，适合的就是最好的。我们真正需要做的是将业务逻辑抽象，做到尽量不依赖任何工具，换工具只需要换一个适配器就好。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eddffa43",
   "metadata": {},
   "source": [
    "&emsp;&emsp;依然使用Docker，启动很简单：\n",
    "\n",
    "```shell\n",
    "docker run -p 6333:6333 -v $(pwd)/qdrant_storage:/qdrant/storage qdrant/qdrant`\n",
    "```\n",
    "\n",
    "&emsp;&emsp;自然也少不了客户端的安装：\n",
    "\n",
    "```shell\n",
    "pip install qdrant-client\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e2d26a2",
   "metadata": {},
   "source": [
    "&emsp;&emsp;不过首先还是生成Embedding，这一步可以使用`get_embedding`接口："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "40d58791",
   "metadata": {},
   "outputs": [],
   "source": [
    "from openai.embeddings_utils import get_embedding, cosine_similarity"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23ddce38",
   "metadata": {},
   "source": [
    "&emsp;&emsp;或者也可以直接使用原生的`Embedding`接口，还支持多条一次请求："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a9607aa6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_embedding_direct(inputs):\n",
    "    embed_model = \"text-embedding-ada-002\"\n",
    "\n",
    "    res = openai.Embedding.create(\n",
    "        input=inputs, engine=embed_model\n",
    "    )\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c2671bef",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3964"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "texts = [v.content for v in df.itertuples()]\n",
    "len(texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "80405446",
   "metadata": {},
   "outputs": [
   ],
   "source": [
    "emds = []\n",
    "for idx, batch in enumerate(pnlp.generate_batches_by_size(texts, 200)):\n",
    "    response = get_embedding_direct(batch)\n",
    "    for v in response.data:\n",
    "        emds.append(v.embedding)\n",
    "    print(f\"batch: {idx} done\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24103d0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(emds), len(emds[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6782ac47",
   "metadata": {},
   "source": [
    "&emsp;&emsp;接下来是创建索引："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "0bd69901",
   "metadata": {},
   "outputs": [],
   "source": [
    "from qdrant_client import QdrantClient\n",
    "client = QdrantClient(host=\"localhost\", port=6333)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a53361e6",
   "metadata": {},
   "source": [
    "&emsp;&emsp;值得注意的是，qdrant还支持内存/文件库，也就是说，可以直接："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "id": "08bc0316",
   "metadata": {},
   "outputs": [],
   "source": [
    "# client = QdrantClient(\":memory:\")\n",
    "# 或\n",
    "# client = QdrantClient(path=\"path/to/db\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d458f00e",
   "metadata": {},
   "source": [
    "&emsp;&emsp;我们还是用server的方式："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b4e9fd89",
   "metadata": {},
   "outputs": [
    {
     "ename": "ResponseHandlingException",
     "evalue": "EOF occurred in violation of protocol (_ssl.c:1129)",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mSSLEOFError\u001b[0m                               Traceback (most recent call last)",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\httpcore\\_exceptions.py:10\u001b[0m, in \u001b[0;36mmap_exceptions\u001b[1;34m(map)\u001b[0m\n\u001b[0;32m      9\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m---> 10\u001b[0m     \u001b[38;5;28;01myield\u001b[39;00m\n\u001b[0;32m     11\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m exc:  \u001b[38;5;66;03m# noqa: PIE786\u001b[39;00m\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\httpcore\\backends\\sync.py:62\u001b[0m, in \u001b[0;36mSyncStream.start_tls\u001b[1;34m(self, ssl_context, server_hostname, timeout)\u001b[0m\n\u001b[0;32m     61\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclose()\n\u001b[1;32m---> 62\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m exc\n\u001b[0;32m     63\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m SyncStream(sock)\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\httpcore\\backends\\sync.py:57\u001b[0m, in \u001b[0;36mSyncStream.start_tls\u001b[1;34m(self, ssl_context, server_hostname, timeout)\u001b[0m\n\u001b[0;32m     56\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sock\u001b[38;5;241m.\u001b[39msettimeout(timeout)\n\u001b[1;32m---> 57\u001b[0m     sock \u001b[38;5;241m=\u001b[39m \u001b[43mssl_context\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrap_socket\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m     58\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sock\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mserver_hostname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mserver_hostname\u001b[49m\n\u001b[0;32m     59\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     60\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m exc:  \u001b[38;5;66;03m# pragma: nocover\u001b[39;00m\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\ssl.py:500\u001b[0m, in \u001b[0;36mSSLContext.wrap_socket\u001b[1;34m(self, sock, server_side, do_handshake_on_connect, suppress_ragged_eofs, server_hostname, session)\u001b[0m\n\u001b[0;32m    494\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrap_socket\u001b[39m(\u001b[38;5;28mself\u001b[39m, sock, server_side\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[0;32m    495\u001b[0m                 do_handshake_on_connect\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[0;32m    496\u001b[0m                 suppress_ragged_eofs\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[0;32m    497\u001b[0m                 server_hostname\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, session\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m    498\u001b[0m     \u001b[38;5;66;03m# SSLSocket class handles server_hostname encoding before it calls\u001b[39;00m\n\u001b[0;32m    499\u001b[0m     \u001b[38;5;66;03m# ctx._wrap_socket()\u001b[39;00m\n\u001b[1;32m--> 500\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msslsocket_class\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_create\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    501\u001b[0m \u001b[43m        \u001b[49m\u001b[43msock\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msock\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    502\u001b[0m \u001b[43m        \u001b[49m\u001b[43mserver_side\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mserver_side\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    503\u001b[0m \u001b[43m        \u001b[49m\u001b[43mdo_handshake_on_connect\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdo_handshake_on_connect\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    504\u001b[0m \u001b[43m        
\u001b[49m\u001b[43msuppress_ragged_eofs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msuppress_ragged_eofs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    505\u001b[0m \u001b[43m        \u001b[49m\u001b[43mserver_hostname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mserver_hostname\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    506\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcontext\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m    507\u001b[0m \u001b[43m        \u001b[49m\u001b[43msession\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msession\u001b[49m\n\u001b[0;32m    508\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\ssl.py:1040\u001b[0m, in \u001b[0;36mSSLSocket._create\u001b[1;34m(cls, sock, server_side, do_handshake_on_connect, suppress_ragged_eofs, server_hostname, context, session)\u001b[0m\n\u001b[0;32m   1039\u001b[0m             \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdo_handshake_on_connect should not be specified for non-blocking sockets\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m-> 1040\u001b[0m         \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdo_handshake\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1041\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mOSError\u001b[39;00m, \u001b[38;5;167;01mValueError\u001b[39;00m):\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\ssl.py:1309\u001b[0m, in \u001b[0;36mSSLSocket.do_handshake\u001b[1;34m(self, block)\u001b[0m\n\u001b[0;32m   1308\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msettimeout(\u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m-> 1309\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sslobj\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdo_handshake\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1310\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n",
      "\u001b[1;31mSSLEOFError\u001b[0m: EOF occurred in violation of protocol (_ssl.c:1129)",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001b[1;31mConnectError\u001b[0m                              Traceback (most recent call last)",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\httpx\\_transports\\default.py:60\u001b[0m, in \u001b[0;36mmap_httpcore_exceptions\u001b[1;34m()\u001b[0m\n\u001b[0;32m     59\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m---> 60\u001b[0m     \u001b[38;5;28;01myield\u001b[39;00m\n\u001b[0;32m     61\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m exc:  \u001b[38;5;66;03m# noqa: PIE-786\u001b[39;00m\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\httpx\\_transports\\default.py:218\u001b[0m, in \u001b[0;36mHTTPTransport.handle_request\u001b[1;34m(self, request)\u001b[0m\n\u001b[0;32m    217\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m map_httpcore_exceptions():\n\u001b[1;32m--> 218\u001b[0m     resp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_pool\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhandle_request\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreq\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    220\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(resp\u001b[38;5;241m.\u001b[39mstream, typing\u001b[38;5;241m.\u001b[39mIterable)\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\httpcore\\_sync\\connection_pool.py:253\u001b[0m, in \u001b[0;36mConnectionPool.handle_request\u001b[1;34m(self, request)\u001b[0m\n\u001b[0;32m    252\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mresponse_closed(status)\n\u001b[1;32m--> 253\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m exc\n\u001b[0;32m    254\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\httpcore\\_sync\\connection_pool.py:237\u001b[0m, in \u001b[0;36mConnectionPool.handle_request\u001b[1;34m(self, request)\u001b[0m\n\u001b[0;32m    236\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 237\u001b[0m     response \u001b[38;5;241m=\u001b[39m \u001b[43mconnection\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhandle_request\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrequest\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    238\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ConnectionNotAvailable:\n\u001b[0;32m    239\u001b[0m     \u001b[38;5;66;03m# The ConnectionNotAvailable exception is a special case, that\u001b[39;00m\n\u001b[0;32m    240\u001b[0m     \u001b[38;5;66;03m# indicates we need to retry the request on a new connection.\u001b[39;00m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    244\u001b[0m     \u001b[38;5;66;03m# might end up as an HTTP/2 connection, but which actually ends\u001b[39;00m\n\u001b[0;32m    245\u001b[0m     \u001b[38;5;66;03m# up as HTTP/1.1.\u001b[39;00m\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\httpcore\\_sync\\http_proxy.py:261\u001b[0m, in \u001b[0;36mTunnelHTTPConnection.handle_request\u001b[1;34m(self, request)\u001b[0m\n\u001b[0;32m    255\u001b[0m connect_request \u001b[38;5;241m=\u001b[39m Request(\n\u001b[0;32m    256\u001b[0m     method\u001b[38;5;241m=\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCONNECT\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m    257\u001b[0m     url\u001b[38;5;241m=\u001b[39mconnect_url,\n\u001b[0;32m    258\u001b[0m     headers\u001b[38;5;241m=\u001b[39mconnect_headers,\n\u001b[0;32m    259\u001b[0m     extensions\u001b[38;5;241m=\u001b[39mrequest\u001b[38;5;241m.\u001b[39mextensions,\n\u001b[0;32m    260\u001b[0m )\n\u001b[1;32m--> 261\u001b[0m connect_response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_connection\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhandle_request\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    262\u001b[0m \u001b[43m    \u001b[49m\u001b[43mconnect_request\u001b[49m\n\u001b[0;32m    263\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    265\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m connect_response\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m200\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m connect_response\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m299\u001b[39m:\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\httpcore\\_sync\\connection.py:86\u001b[0m, in \u001b[0;36mHTTPConnection.handle_request\u001b[1;34m(self, request)\u001b[0m\n\u001b[0;32m     85\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_connect_failed \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m---> 86\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m exc\n\u001b[0;32m     87\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_connection\u001b[38;5;241m.\u001b[39mis_available():\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\httpcore\\_sync\\connection.py:63\u001b[0m, in \u001b[0;36mHTTPConnection.handle_request\u001b[1;34m(self, request)\u001b[0m\n\u001b[0;32m     62\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m---> 63\u001b[0m     stream \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_connect\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrequest\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     65\u001b[0m     ssl_object \u001b[38;5;241m=\u001b[39m stream\u001b[38;5;241m.\u001b[39mget_extra_info(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mssl_object\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\httpcore\\_sync\\connection.py:150\u001b[0m, in \u001b[0;36mHTTPConnection._connect\u001b[1;34m(self, request)\u001b[0m\n\u001b[0;32m    149\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m Trace(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mconnection.start_tls\u001b[39m\u001b[38;5;124m\"\u001b[39m, request, kwargs) \u001b[38;5;28;01mas\u001b[39;00m trace:\n\u001b[1;32m--> 150\u001b[0m     stream \u001b[38;5;241m=\u001b[39m stream\u001b[38;5;241m.\u001b[39mstart_tls(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m    151\u001b[0m     trace\u001b[38;5;241m.\u001b[39mreturn_value \u001b[38;5;241m=\u001b[39m stream\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\httpcore\\backends\\sync.py:62\u001b[0m, in \u001b[0;36mSyncStream.start_tls\u001b[1;34m(self, ssl_context, server_hostname, timeout)\u001b[0m\n\u001b[0;32m     61\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclose()\n\u001b[1;32m---> 62\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m exc\n\u001b[0;32m     63\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m SyncStream(sock)\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\contextlib.py:137\u001b[0m, in \u001b[0;36m_GeneratorContextManager.__exit__\u001b[1;34m(self, typ, value, traceback)\u001b[0m\n\u001b[0;32m    136\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 137\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgen\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mthrow\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtyp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtraceback\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    138\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m exc:\n\u001b[0;32m    139\u001b[0m     \u001b[38;5;66;03m# Suppress StopIteration *unless* it's the same exception that\u001b[39;00m\n\u001b[0;32m    140\u001b[0m     \u001b[38;5;66;03m# was passed to throw().  This prevents a StopIteration\u001b[39;00m\n\u001b[0;32m    141\u001b[0m     \u001b[38;5;66;03m# raised inside the \"with\" statement from being suppressed.\u001b[39;00m\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\httpcore\\_exceptions.py:14\u001b[0m, in \u001b[0;36mmap_exceptions\u001b[1;34m(map)\u001b[0m\n\u001b[0;32m     13\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(exc, from_exc):\n\u001b[1;32m---> 14\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m to_exc(exc)\n\u001b[0;32m     15\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n",
      "\u001b[1;31mConnectError\u001b[0m: EOF occurred in violation of protocol (_ssl.c:1129)",
      "\nThe above exception was the direct cause of the following exception:\n",
      "\u001b[1;31mConnectError\u001b[0m                              Traceback (most recent call last)",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\qdrant_client\\http\\api_client.py:95\u001b[0m, in \u001b[0;36mApiClient.send_inner\u001b[1;34m(self, request)\u001b[0m\n\u001b[0;32m     94\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m---> 95\u001b[0m     response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_client\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrequest\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     96\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\httpx\\_client.py:901\u001b[0m, in \u001b[0;36mClient.send\u001b[1;34m(self, request, stream, auth, follow_redirects)\u001b[0m\n\u001b[0;32m    899\u001b[0m auth \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_build_request_auth(request, auth)\n\u001b[1;32m--> 901\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_send_handling_auth\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    902\u001b[0m \u001b[43m    \u001b[49m\u001b[43mrequest\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    903\u001b[0m \u001b[43m    \u001b[49m\u001b[43mauth\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mauth\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    904\u001b[0m \u001b[43m    \u001b[49m\u001b[43mfollow_redirects\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfollow_redirects\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    905\u001b[0m \u001b[43m    \u001b[49m\u001b[43mhistory\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    906\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    907\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\httpx\\_client.py:929\u001b[0m, in \u001b[0;36mClient._send_handling_auth\u001b[1;34m(self, request, auth, follow_redirects, history)\u001b[0m\n\u001b[0;32m    928\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[1;32m--> 929\u001b[0m     response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_send_handling_redirects\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    930\u001b[0m \u001b[43m        \u001b[49m\u001b[43mrequest\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    931\u001b[0m \u001b[43m        \u001b[49m\u001b[43mfollow_redirects\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfollow_redirects\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    932\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhistory\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhistory\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    933\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    934\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\httpx\\_client.py:966\u001b[0m, in \u001b[0;36mClient._send_handling_redirects\u001b[1;34m(self, request, follow_redirects, history)\u001b[0m\n\u001b[0;32m    964\u001b[0m     hook(request)\n\u001b[1;32m--> 966\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_send_single_request\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrequest\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    967\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\httpx\\_client.py:1002\u001b[0m, in \u001b[0;36mClient._send_single_request\u001b[1;34m(self, request)\u001b[0m\n\u001b[0;32m   1001\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m request_context(request\u001b[38;5;241m=\u001b[39mrequest):\n\u001b[1;32m-> 1002\u001b[0m     response \u001b[38;5;241m=\u001b[39m \u001b[43mtransport\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhandle_request\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrequest\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1004\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(response\u001b[38;5;241m.\u001b[39mstream, SyncByteStream)\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\httpx\\_transports\\default.py:218\u001b[0m, in \u001b[0;36mHTTPTransport.handle_request\u001b[1;34m(self, request)\u001b[0m\n\u001b[0;32m    217\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m map_httpcore_exceptions():\n\u001b[1;32m--> 218\u001b[0m     resp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pool\u001b[38;5;241m.\u001b[39mhandle_request(req)\n\u001b[0;32m    220\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(resp\u001b[38;5;241m.\u001b[39mstream, typing\u001b[38;5;241m.\u001b[39mIterable)\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\contextlib.py:137\u001b[0m, in \u001b[0;36m_GeneratorContextManager.__exit__\u001b[1;34m(self, typ, value, traceback)\u001b[0m\n\u001b[0;32m    136\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 137\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgen\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mthrow\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtyp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtraceback\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    138\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m exc:\n\u001b[0;32m    139\u001b[0m     \u001b[38;5;66;03m# Suppress StopIteration *unless* it's the same exception that\u001b[39;00m\n\u001b[0;32m    140\u001b[0m     \u001b[38;5;66;03m# was passed to throw().  This prevents a StopIteration\u001b[39;00m\n\u001b[0;32m    141\u001b[0m     \u001b[38;5;66;03m# raised inside the \"with\" statement from being suppressed.\u001b[39;00m\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\httpx\\_transports\\default.py:77\u001b[0m, in \u001b[0;36mmap_httpcore_exceptions\u001b[1;34m()\u001b[0m\n\u001b[0;32m     76\u001b[0m message \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mstr\u001b[39m(exc)\n\u001b[1;32m---> 77\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m mapped_exc(message) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mexc\u001b[39;00m\n",
      "\u001b[1;31mConnectError\u001b[0m: EOF occurred in violation of protocol (_ssl.c:1129)",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001b[1;31mResponseHandlingException\u001b[0m                 Traceback (most recent call last)",
      "Input \u001b[1;32mIn [12]\u001b[0m, in \u001b[0;36m<cell line: 3>\u001b[1;34m()\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mqdrant_client\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmodels\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Distance, VectorParams\n\u001b[1;32m----> 3\u001b[0m \u001b[43mclient\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrecreate_collection\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m      4\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcollection_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdoc_qa\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m      5\u001b[0m \u001b[43m    \u001b[49m\u001b[43mvectors_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mVectorParams\u001b[49m\u001b[43m(\u001b[49m\u001b[43msize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1536\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdistance\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mDistance\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mCOSINE\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m      6\u001b[0m \u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\qdrant_client\\qdrant_client.py:1040\u001b[0m, in \u001b[0;36mQdrantClient.recreate_collection\u001b[1;34m(self, collection_name, vectors_config, shard_number, replication_factor, write_consistency_factor, on_disk_payload, hnsw_config, optimizers_config, wal_config, quantization_config, init_from, timeout, **kwargs)\u001b[0m\n\u001b[0;32m    986\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrecreate_collection\u001b[39m(\n\u001b[0;32m    987\u001b[0m     \u001b[38;5;28mself\u001b[39m,\n\u001b[0;32m    988\u001b[0m     collection_name: \u001b[38;5;28mstr\u001b[39m,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   1000\u001b[0m     \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any,\n\u001b[0;32m   1001\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mbool\u001b[39m:\n\u001b[0;32m   1002\u001b[0m     \u001b[38;5;124;03m\"\"\"Delete and create empty collection with given parameters\u001b[39;00m\n\u001b[0;32m   1003\u001b[0m \n\u001b[0;32m   1004\u001b[0m \u001b[38;5;124;03m    Args:\u001b[39;00m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   1037\u001b[0m \u001b[38;5;124;03m        Operation result\u001b[39;00m\n\u001b[0;32m   1038\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[1;32m-> 1040\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_client\u001b[38;5;241m.\u001b[39mrecreate_collection(\n\u001b[0;32m   1041\u001b[0m         collection_name\u001b[38;5;241m=\u001b[39mcollection_name,\n\u001b[0;32m   1042\u001b[0m         vectors_config\u001b[38;5;241m=\u001b[39mvectors_config,\n\u001b[0;32m   1043\u001b[0m         shard_number\u001b[38;5;241m=\u001b[39mshard_number,\n\u001b[0;32m   1044\u001b[0m         replication_factor\u001b[38;5;241m=\u001b[39mreplication_factor,\n\u001b[0;32m   1045\u001b[0m         
write_consistency_factor\u001b[38;5;241m=\u001b[39mwrite_consistency_factor,\n\u001b[0;32m   1046\u001b[0m         on_disk_payload\u001b[38;5;241m=\u001b[39mon_disk_payload,\n\u001b[0;32m   1047\u001b[0m         hnsw_config\u001b[38;5;241m=\u001b[39mhnsw_config,\n\u001b[0;32m   1048\u001b[0m         optimizers_config\u001b[38;5;241m=\u001b[39moptimizers_config,\n\u001b[0;32m   1049\u001b[0m         wal_config\u001b[38;5;241m=\u001b[39mwal_config,\n\u001b[0;32m   1050\u001b[0m         quantization_config\u001b[38;5;241m=\u001b[39mquantization_config,\n\u001b[0;32m   1051\u001b[0m         init_from\u001b[38;5;241m=\u001b[39minit_from,\n\u001b[0;32m   1052\u001b[0m         timeout\u001b[38;5;241m=\u001b[39mtimeout,\n\u001b[0;32m   1053\u001b[0m         \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[0;32m   1054\u001b[0m     )\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\qdrant_client\\qdrant_remote.py:1811\u001b[0m, in \u001b[0;36mQdrantRemote.recreate_collection\u001b[1;34m(self, collection_name, vectors_config, shard_number, replication_factor, write_consistency_factor, on_disk_payload, hnsw_config, optimizers_config, wal_config, quantization_config, init_from, timeout, **kwargs)\u001b[0m\n\u001b[0;32m   1757\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrecreate_collection\u001b[39m(\n\u001b[0;32m   1758\u001b[0m     \u001b[38;5;28mself\u001b[39m,\n\u001b[0;32m   1759\u001b[0m     collection_name: \u001b[38;5;28mstr\u001b[39m,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   1771\u001b[0m     \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any,\n\u001b[0;32m   1772\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mbool\u001b[39m:\n\u001b[0;32m   1773\u001b[0m     \u001b[38;5;124;03m\"\"\"Delete and create empty collection with given parameters\u001b[39;00m\n\u001b[0;32m   1774\u001b[0m \n\u001b[0;32m   1775\u001b[0m \u001b[38;5;124;03m    Args:\u001b[39;00m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   1808\u001b[0m \u001b[38;5;124;03m        Operation result\u001b[39;00m\n\u001b[0;32m   1809\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[1;32m-> 1811\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete_collection\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcollection_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1813\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcreate_collection(\n\u001b[0;32m   1814\u001b[0m         collection_name\u001b[38;5;241m=\u001b[39mcollection_name,\n\u001b[0;32m   1815\u001b[0m         
vectors_config\u001b[38;5;241m=\u001b[39mvectors_config,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   1824\u001b[0m         timeout\u001b[38;5;241m=\u001b[39mtimeout,\n\u001b[0;32m   1825\u001b[0m     )\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\qdrant_client\\qdrant_remote.py:1664\u001b[0m, in \u001b[0;36mQdrantRemote.delete_collection\u001b[1;34m(self, collection_name, timeout, **kwargs)\u001b[0m\n\u001b[0;32m   1650\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdelete_collection\u001b[39m(\n\u001b[0;32m   1651\u001b[0m     \u001b[38;5;28mself\u001b[39m, collection_name: \u001b[38;5;28mstr\u001b[39m, timeout: Optional[\u001b[38;5;28mint\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any\n\u001b[0;32m   1652\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mbool\u001b[39m:\n\u001b[0;32m   1653\u001b[0m     \u001b[38;5;124;03m\"\"\"Removes collection and all it's data\u001b[39;00m\n\u001b[0;32m   1654\u001b[0m \n\u001b[0;32m   1655\u001b[0m \u001b[38;5;124;03m    Args:\u001b[39;00m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   1662\u001b[0m \u001b[38;5;124;03m        Operation result\u001b[39;00m\n\u001b[0;32m   1663\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[1;32m-> 1664\u001b[0m     result: Optional[\u001b[38;5;28mbool\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhttp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcollections_api\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete_collection\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m   1665\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcollection_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\n\u001b[0;32m   1666\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mresult\n\u001b[0;32m   1667\u001b[0m     \u001b[38;5;28;01massert\u001b[39;00m result \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m 
\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDelete collection returned None\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m   1668\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m result\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\qdrant_client\\http\\api\\collections_api.py:866\u001b[0m, in \u001b[0;36mSyncCollectionsApi.delete_collection\u001b[1;34m(self, collection_name, timeout)\u001b[0m\n\u001b[0;32m    858\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdelete_collection\u001b[39m(\n\u001b[0;32m    859\u001b[0m     \u001b[38;5;28mself\u001b[39m,\n\u001b[0;32m    860\u001b[0m     collection_name: \u001b[38;5;28mstr\u001b[39m,\n\u001b[0;32m    861\u001b[0m     timeout: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[0;32m    862\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m m\u001b[38;5;241m.\u001b[39mInlineResponse2003:\n\u001b[0;32m    863\u001b[0m     \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m    864\u001b[0m \u001b[38;5;124;03m    Drop collection and all associated data\u001b[39;00m\n\u001b[0;32m    865\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[1;32m--> 866\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_build_for_delete_collection\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    867\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcollection_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcollection_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    868\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    869\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\qdrant_client\\http\\api\\collections_api.py:274\u001b[0m, in \u001b[0;36m_CollectionsApi._build_for_delete_collection\u001b[1;34m(self, collection_name, timeout)\u001b[0m\n\u001b[0;32m    271\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m    272\u001b[0m     query_params[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtimeout\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mstr\u001b[39m(timeout)\n\u001b[1;32m--> 274\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapi_client\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequest\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    275\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtype_\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mInlineResponse2003\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    276\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmethod\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mDELETE\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m    277\u001b[0m \u001b[43m    \u001b[49m\u001b[43murl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m/collections/\u001b[39;49m\u001b[38;5;132;43;01m{collection_name}\u001b[39;49;00m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m    278\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpath_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpath_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    279\u001b[0m \u001b[43m    
\u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mquery_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    280\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\qdrant_client\\http\\api_client.py:68\u001b[0m, in \u001b[0;36mApiClient.request\u001b[1;34m(self, type_, method, url, path_params, **kwargs)\u001b[0m\n\u001b[0;32m     66\u001b[0m url \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhost \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;241m+\u001b[39m url\u001b[38;5;241m.\u001b[39mformat(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mpath_params)\n\u001b[0;32m     67\u001b[0m request \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_client\u001b[38;5;241m.\u001b[39mbuild_request(method, url, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m---> 68\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrequest\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtype_\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\qdrant_client\\http\\api_client.py:85\u001b[0m, in \u001b[0;36mApiClient.send\u001b[1;34m(self, request, type_)\u001b[0m\n\u001b[0;32m     84\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msend\u001b[39m(\u001b[38;5;28mself\u001b[39m, request: Request, type_: Type[T]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m T:\n\u001b[1;32m---> 85\u001b[0m     response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmiddleware\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrequest\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msend_inner\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     86\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m response\u001b[38;5;241m.\u001b[39mstatus_code \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;241m200\u001b[39m, \u001b[38;5;241m201\u001b[39m]:\n\u001b[0;32m     87\u001b[0m         \u001b[38;5;28;01mtry\u001b[39;00m:\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\qdrant_client\\http\\api_client.py:188\u001b[0m, in \u001b[0;36mBaseMiddleware.__call__\u001b[1;34m(self, request, call_next)\u001b[0m\n\u001b[0;32m    187\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, request: Request, call_next: Send) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Response:\n\u001b[1;32m--> 188\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcall_next\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrequest\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\qdrant_client\\http\\api_client.py:97\u001b[0m, in \u001b[0;36mApiClient.send_inner\u001b[1;34m(self, request)\u001b[0m\n\u001b[0;32m     95\u001b[0m     response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_client\u001b[38;5;241m.\u001b[39msend(request)\n\u001b[0;32m     96\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m---> 97\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m ResponseHandlingException(e)\n\u001b[0;32m     98\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m response\n",
      "\u001b[1;31mResponseHandlingException\u001b[0m: EOF occurred in violation of protocol (_ssl.c:1129)"
     ]
    }
   ],
   "source": [
    "from qdrant_client.models import Distance, VectorParams\n",
    "\n",
    "client.recreate_collection(\n",
    "    collection_name=\"doc_qa\",\n",
    "    vectors_config=VectorParams(size=1536, distance=Distance.COSINE),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "id": "b64fec60",
   "metadata": {},
   "outputs": [],
   "source": [
    "# client.delete_collection(\"doc_qa\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f87b97d5",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Next, load the vectors into the collection:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "id": "3094fc8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "payload=[\n",
    "    {\"content\": v.content, \"heading\": v.heading, \"title\": v.title, \"tokens\": v.tokens} for v in df.itertuples()\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 153,
   "id": "a3f4ecfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "client.upload_collection(\n",
    "    collection_name=\"doc_qa\",\n",
    "    vectors=emds,\n",
    "    payload=payload\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3c0db92",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Next, run a query:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "id": "f061868f",
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"Who won the 2020 Summer Olympics men's high jump?\"\n",
    "query_vector = get_embedding(query, engine=\"text-embedding-ada-002\")\n",
    "hits = client.search(\n",
    "    collection_name=\"doc_qa\",\n",
    "    query_vector=query_vector,\n",
    "    limit=5\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "id": "90ca55da",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[ScoredPoint(id=236, version=3, score=0.90316474, payload={'content': 'The men\\'s high jump event at the 2020 Summer Olympics took place between 30 July and 1 August 2021 at the Olympic Stadium. 33 athletes from 24 nations competed; the total possible number depended on how many nations would use universality places to enter athletes in addition to the 32 qualifying through mark or ranking (no universality places were used in 2021). Italian athlete Gianmarco Tamberi along with Qatari athlete Mutaz Essa Barshim emerged as joint winners of the event following a tie between both of them as they cleared 2.37m. Both Tamberi and Barshim agreed to share the gold medal in a rare instance where the athletes of different nations had agreed to share the same medal in the history of Olympics. Barshim in particular was heard to ask a competition official \"Can we have two golds?\" in response to being offered a \\'jump off\\'. Maksim Nedasekau of Belarus took bronze. The medals were the first ever in the men\\'s high jump for Italy and Belarus, the first gold in the men\\'s high jump for Italy and Qatar, and the third consecutive medal in the men\\'s high jump for Qatar (all by Barshim). Barshim became only the second man to earn three medals in high jump, joining Patrik Sjöberg of Sweden (1984 to 1992).', 'heading': 'Summary', 'title': \"Athletics at the 2020 Summer Olympics – Men's high jump\", 'tokens': 275}, vector=None),\n",
       " ScoredPoint(id=313, version=4, score=0.88258004, payload={'content': \"The men's long jump event at the 2020 Summer Olympics took place between 31 July and 2 August 2021 at the Japan National Stadium. Approximately 35 athletes were expected to compete; the exact number was dependent on how many nations use universality places to enter athletes in addition to the 32 qualifying through time or ranking (1 universality place was used in 2016). 31 athletes from 20 nations competed. Miltiadis Tentoglou won the gold medal, Greece's first medal in the men's long jump. Cuban athletes Juan Miguel Echevarría and Maykel Massó earned silver and bronze, respectively, the nation's first medals in the event since 2008.\", 'heading': 'Summary', 'title': \"Athletics at the 2020 Summer Olympics – Men's long jump\", 'tokens': 136}, vector=None),\n",
       " ScoredPoint(id=284, version=4, score=0.8821836, payload={'content': \"The men's pole vault event at the 2020 Summer Olympics took place between 31 July and 3 August 2021 at the Japan National Stadium. 29 athletes from 18 nations competed. Armand Duplantis of Sweden won gold, with Christopher Nilsen of the United States earning silver and Thiago Braz of Brazil taking bronze. It was Sweden's first victory in the event and first medal of any color in the men's pole vault since 1952. Braz, who had won in 2016, became the ninth man to earn multiple medals in the pole vault.\", 'heading': 'Summary', 'title': \"Athletics at the 2020 Summer Olympics – Men's pole vault\", 'tokens': 112}, vector=None),\n",
       " ScoredPoint(id=222, version=3, score=0.876395, payload={'content': \"The men's triple jump event at the 2020 Summer Olympics took place between 3 and 5 August 2021 at the Japan National Stadium. Approximately 35 athletes were expected to compete; the exact number was dependent on how many nations use universality places to enter athletes in addition to the 32 qualifying through time or ranking (2 universality places were used in 2016). 32 athletes from 19 nations competed. Pedro Pichardo of Portugal won the gold medal, the nation's second victory in the men's triple jump (after Nelson Évora in 2008). China's Zhu Yaming took silver, while Hugues Fabrice Zango earned Burkina Faso's first Olympic medal in any event.\", 'heading': 'Summary', 'title': \"Athletics at the 2020 Summer Olympics – Men's triple jump\", 'tokens': 139}, vector=None),\n",
       " ScoredPoint(id=205, version=3, score=0.86075026, payload={'content': \"The men's 110 metres hurdles event at the 2020 Summer Olympics took place between 3 and 5 August 2021 at the Olympic Stadium. Approximately forty athletes were expected to compete; the exact number was dependent on how many nations used universality places to enter athletes in addition to the 40 qualifying through time or ranking (1 universality place was used in 2016). 40 athletes from 29 nations competed. Hansle Parchment of Jamaica won the gold medal, the nation's second consecutive victory in the event. His countryman Ronald Levy took bronze. American Grant Holloway earned silver, placing the United States back on the podium in the event after the nation missed the medals for the first time in Rio 2016 (excluding the boycotted 1980 Games).\", 'heading': 'Summary', 'title': \"Athletics at the 2020 Summer Olympics – Men's 110 metres hurdles\", 'tokens': 149}, vector=None)]"
      ]
     },
     "execution_count": 155,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hits"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba00d42e",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Next, wrap this retrieval step into the prompt-construction process:"
   ]
  },
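  {
   "cell_type": "code",
   "execution_count": 166,
   "id": "a28101c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_SECTION_LEN = 500  # token budget for the retrieved context\n",
    "SEPARATOR = \"\\n* \"  # bullet prefix placed before each context section\n",
    "separator_len = 3  # token cost charged per separator"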
   ]
  },
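  {
   "cell_type": "code",
   "execution_count": 167,
   "id": "0dde385a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_prompt(question: str) -> str:\n",
    "    \"\"\"Retrieve the top hits for `question` and pack them into a QA prompt within MAX_SECTION_LEN tokens.\"\"\"\n",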
    "    query_vector = get_embedding(question, engine=\"text-embedding-ada-002\")\n",
    "    hits = client.search(\n",
    "        collection_name=\"doc_qa\",\n",
    "        query_vector=query_vector,\n",
    "        limit=5\n",
    "    )\n",
    "    \n",
    "    choose = []\n",
    "    length = 0\n",
    "    indexes = []\n",
    "     \n",
    "    for hit in hits:\n",
    "        doc = hit.payload\n",
    "        length += doc[\"tokens\"] + separator_len\n",
    "        if length > MAX_SECTION_LEN:\n",
    "            break\n",
    "            \n",
    "        choose.append(SEPARATOR + doc[\"content\"].replace(\"\\n\", \" \"))\n",
    "        indexes.append(doc[\"title\"] + doc[\"heading\"])\n",
    "            \n",
    "    # Useful diagnostic information\n",
    "    print(f\"Selected {len(choose)} document sections:\")\n",
    "    print(\"\\n\".join(indexes))\n",
    "    \n",
    "    header = \"\"\"Answer the question as truthfully as possible using the provided context, and if the answer is not contained within the text below, say \"I don't know.\"\\n\\nContext:\\n\"\"\"\n",
    "    \n",
    "    return header + \"\".join(choose) + \"\\n\\n Q: \" + question + \"\\n A:\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 175,
   "id": "30cc5008",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected 2 document sections:\n",
      "Athletics at the 2020 Summer Olympics – Men's high jumpSummary\n",
      "Athletics at the 2020 Summer Olympics – Men's long jumpSummary\n",
      "===\n",
      " Answer the question as truthfully as possible using the provided context, and if the answer is not contained within the text below, say \"I don't know.\"\n",
      "\n",
      "Context:\n",
      "\n",
      "* The men's high jump event at the 2020 Summer Olympics took place between 30 July and 1 August 2021 at the Olympic Stadium. 33 athletes from 24 nations competed; the total possible number depended on how many nations would use universality places to enter athletes in addition to the 32 qualifying through mark or ranking (no universality places were used in 2021). Italian athlete Gianmarco Tamberi along with Qatari athlete Mutaz Essa Barshim emerged as joint winners of the event following a tie between both of them as they cleared 2.37m. Both Tamberi and Barshim agreed to share the gold medal in a rare instance where the athletes of different nations had agreed to share the same medal in the history of Olympics. Barshim in particular was heard to ask a competition official \"Can we have two golds?\" in response to being offered a 'jump off'. Maksim Nedasekau of Belarus took bronze. The medals were the first ever in the men's high jump for Italy and Belarus, the first gold in the men's high jump for Italy and Qatar, and the third consecutive medal in the men's high jump for Qatar (all by Barshim). Barshim became only the second man to earn three medals in high jump, joining Patrik Sjöberg of Sweden (1984 to 1992).\n",
      "* The men's long jump event at the 2020 Summer Olympics took place between 31 July and 2 August 2021 at the Japan National Stadium. Approximately 35 athletes were expected to compete; the exact number was dependent on how many nations use universality places to enter athletes in addition to the 32 qualifying through time or ranking (1 universality place was used in 2016). 31 athletes from 20 nations competed. Miltiadis Tentoglou won the gold medal, Greece's first medal in the men's long jump. Cuban athletes Juan Miguel Echevarría and Maykel Massó earned silver and bronze, respectively, the nation's first medals in the event since 2008.\n",
      "\n",
      " Q: Who won the 2020 Summer Olympics men's high jump?\n",
      " A:\n"
     ]
    }
   ],
   "source": [
    "prompt = construct_prompt(\"Who won the 2020 Summer Olympics men's high jump?\")\n",
    "\n",
    "print(\"===\\n\", prompt)"
   ]
  },
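  {
   "cell_type": "code",
   "execution_count": 170,
   "id": "b52801aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def complete(prompt):\n",
    "    \"\"\"Answer `prompt` with temperature=0 via the legacy Completion endpoint and return the stripped text.\"\"\"\n",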
    "    response = openai.Completion.create(\n",
    "        prompt=prompt,\n",
    "        temperature=0,\n",
    "        max_tokens=300,\n",
    "        top_p=1,\n",
    "        frequency_penalty=0,\n",
    "        presence_penalty=0,\n",
    "        model=\"text-davinci-003\"\n",
    "    )\n",
    "    ans = response[\"choices\"][0][\"text\"].strip(\" \\n\")\n",
    "    return ans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 171,
   "id": "8260a739",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Gianmarco Tamberi and Mutaz Essa Barshim emerged as joint winners of the event following a tie between both of them as they cleared 2.37m. Both Tamberi and Barshim agreed to share the gold medal.'"
      ]
     },
     "execution_count": 171,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "complete(prompt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81a73c59",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Now let's try the `ChatCompletion` (ChatGPT) interface:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 174,
   "id": "5c73a366",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ask(content):\n",
    "    response = openai.ChatCompletion.create(\n",
    "        model=\"gpt-3.5-turbo\", \n",
    "        messages=[{\"role\": \"user\", \"content\": content}]\n",
    "    )\n",
    "\n",
    "    ans = response.get(\"choices\")[0].get(\"message\").get(\"content\")\n",
    "    return ans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "id": "df1a9b12",
   "metadata": {},
   "outputs": [],
   "source": [
    "ans = ask(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "id": "620f32ca",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"Gianmarco Tamberi and Mutaz Essa Barshim shared the gold medal in the men's high jump event at the 2020 Summer Olympics.\""
      ]
     },
     "execution_count": 177,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ans"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf255281",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Let's look at a few more examples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 181,
   "id": "121054d7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected 1 document sections:\n",
      "Concerns and controversies at the 2020 Summer OlympicsSummary\n",
      "\n",
      "Q: Why was the 2020 Summer Olympics originally postponed?\n",
      "A: The 2020 Summer Olympics were originally postponed due to the COVID-19 pandemic.\n"
     ]
    }
   ],
   "source": [
    "query = \"Why was the 2020 Summer Olympics originally postponed?\"\n",
    "prompt = construct_prompt(query)\n",
    "answer = complete(prompt)\n",
    "\n",
    "print(f\"\\nQ: {query}\\nA: {answer}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 182,
   "id": "912ffb89",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected 2 document sections:\n",
      "2020 Summer Olympics medal tableSummary\n",
      "List of 2020 Summer Olympics medal winnersSummary\n",
      "\n",
      "Q: In the 2020 Summer Olympics, how many gold medals did the country which won the most medals win?\n",
      "A: The United States won the most medals overall, with 113, and the most gold medals, with 39.\n"
     ]
    }
   ],
   "source": [
    "query = \"In the 2020 Summer Olympics, how many gold medals did the country which won the most medals win?\"\n",
    "prompt = construct_prompt(query)\n",
    "answer = complete(prompt)\n",
    "\n",
    "print(f\"\\nQ: {query}\\nA: {answer}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 183,
   "id": "ba7d008e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Q: In the 2020 Summer Olympics, how many gold medals did the country which won the most medals win?\n",
      "A: The country that won the most medals at the 2020 Summer Olympics was the United States, with 113 medals, including 39 gold medals.\n"
     ]
    }
   ],
   "source": [
    "# ChatGPT\n",
    "answer = ask(prompt)\n",
    "\n",
    "print(f\"\\nQ: {query}\\nA: {answer}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 184,
   "id": "59d245fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected 3 document sections:\n",
      "Sport climbing at the 2020 Summer Olympics – Men's combinedRoute-setting\n",
      "Ski mountaineering at the 2020 Winter Youth Olympics – Boys' individualSummary\n",
      "Ski mountaineering at the 2020 Winter Youth Olympics – Girls' individualSummary\n",
      "\n",
      "Q: What is the tallest mountain in the world?\n",
      "A: I don't know.\n"
     ]
    }
   ],
   "source": [
    "query = \"What is the tallest mountain in the world?\"\n",
    "prompt = construct_prompt(query)\n",
    "answer = complete(prompt)\n",
    "\n",
    "print(f\"\\nQ: {query}\\nA: {answer}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 185,
   "id": "ebb1568e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Q: What is the tallest mountain in the world?\n",
      "A: I don't know.\n"
     ]
    }
   ],
   "source": [
    "# ChatGPT\n",
    "answer = ask(prompt)\n",
    "\n",
    "print(f\"\\nQ: {query}\\nA: {answer}\")"
   ]
  },
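  {
   "cell_type": "markdown",
   "id": "ff8b2950",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Very nice, isn't it? Both interfaces answer from the retrieved context and fall back to \"I don't know.\" when the context does not contain the answer."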
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e701b907",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Fine-tuning for Classification and Entity Extraction"
   ]
  },
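  {
   "cell_type": "markdown",
   "id": "9b0f94d0",
   "metadata": {},
   "source": [
    "&emsp;&emsp;In the earlier \"相关API\" (Related APIs) section, we covered the general usage for classification and entity extraction. Here we look at more concrete, common tasks."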
   ]
  },
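  {
   "cell_type": "markdown",
   "id": "77b2c281",
   "metadata": {},
   "source": [
    "&emsp;&emsp;First up is topic classification: given a text, determine which topic category it belongs to.\n",
    "\n",
    "&emsp;&emsp;Let's look at a news topic classification dataset, taken from [CLUEbenchmark/CLUE](https://github.com/CLUEbenchmark/CLUE), with 15 categories in total."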
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "efb2f03f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pnlp\n",
    "lines = pnlp.read_file_to_list_dict(\"data/tnews.json\", encoding='utf-8')\n",
    "len(lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "077ff261",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "ct = Counter([v[\"label_desc\"] for v in lines])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ffe20844",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('news_tech', 1089),\n",
       " ('news_finance', 956),\n",
       " ('news_entertainment', 910),\n",
       " ('news_world', 905),\n",
       " ('news_car', 791),\n",
       " ('news_sports', 767),\n",
       " ('news_culture', 736),\n",
       " ('news_military', 716),\n",
       " ('news_travel', 693),\n",
       " ('news_game', 659),\n",
       " ('news_edu', 646),\n",
       " ('news_agriculture', 494),\n",
       " ('news_house', 378),\n",
       " ('news_story', 215),\n",
       " ('news_stock', 45)]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ct.most_common()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d62346b9",
   "metadata": {},
   "source": [
    "The `news_stock` category has too few samples (only 45), so we drop it first:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "91013a1c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9955"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lines = [v for v in lines if v[\"label_desc\"] != \"news_stock\"]\n",
    "len(lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "39357e51",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_prompt(text):\n",
    "    prompt = f\"\"\"对给定文本进行分类，类别包括：科技、金融、娱乐、世界、汽车、文化、军事、旅游、游戏、教育、农业、房产、社会、股票。\n",
    "\n",
    "给定文本：\n",
    "{text}\n",
    "类别：\n",
    "\"\"\"\n",
    "    return prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "1d86eef9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'label': '102',\n",
       " 'label_desc': 'news_entertainment',\n",
       " 'sentence': '江疏影甜甜圈自拍，迷之角度竟这么好看，美吸引一切事物',\n",
       " 'keywords': '江疏影,美少女,经纪人,甜甜圈'}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lines[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "13b4196e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "对给定文本进行分类，类别包括：科技、金融、娱乐、世界、汽车、文化、军事、旅游、游戏、教育、农业、房产、社会、股票。\n",
      "\n",
      "给定文本：\n",
      "江疏影甜甜圈自拍，迷之角度竟这么好看，美吸引一切事物\n",
      "类别：\n",
      "\n"
     ]
    }
   ],
   "source": [
    "prompt = get_prompt(lines[0][\"sentence\"])\n",
    "print(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e83393be",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import openai\n",
    "\n",
    "openai.api_key = os.environ.get(\"OPENAI_API_KEY\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4516aca6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def complete(prompt):\n",
    "    response = openai.Completion.create(\n",
    "        prompt=prompt,\n",
    "        temperature=0,\n",
    "        max_tokens=10,\n",
    "        top_p=1,\n",
    "        frequency_penalty=0,\n",
    "        presence_penalty=0,\n",
    "        model=\"text-davinci-003\"\n",
    "    )\n",
    "    ans = response[\"choices\"][0][\"text\"].strip(\" \\n\")\n",
    "    return ans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a52c2e14",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'娱乐'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "complete(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "0b8b02f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ask(content):\n",
    "    response = openai.ChatCompletion.create(\n",
    "        model=\"gpt-3.5-turbo\", \n",
    "        messages=[{\"role\": \"user\", \"content\": content}],\n",
    "        temperature=0,\n",
    "        max_tokens=10,\n",
    "        top_p=1,\n",
    "        frequency_penalty=0,\n",
    "        presence_penalty=0,\n",
    "    )\n",
    "\n",
    "    ans = response.get(\"choices\")[0].get(\"message\").get(\"content\")\n",
    "    return ans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "2a0f37d6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'娱乐'"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ask(prompt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c2ca268",
   "metadata": {},
   "source": [
    "Let's try a few samples from other categories:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "id": "3b88a272",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'label': '110',\n",
       " 'label_desc': 'news_military',\n",
       " 'sentence': '以色列大规模空袭开始！伊朗多个军事目标遭遇打击，誓言对等反击',\n",
       " 'keywords': '伊朗,圣城军,叙利亚,以色列国防军,以色列'}"
      ]
     },
     "execution_count": 113,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lines[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "id": "bec207e7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "对给定文本进行分类，类别包括：科技、金融、娱乐、世界、汽车、文化、军事、旅游、游戏、教育、农业、房产、社会、股票。\n",
      "\n",
      "给定文本：\n",
      "以色列大规模空袭开始！伊朗多个军事目标遭遇打击，誓言对等反击\n",
      "类别：\n",
      "\n"
     ]
    }
   ],
   "source": [
    "prompt = get_prompt(lines[1][\"sentence\"])\n",
    "print(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 210,
   "id": "66e8d5d3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'军事'"
      ]
     },
     "execution_count": 210,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ask(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 211,
   "id": "e6fdf09e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'军事'"
      ]
     },
     "execution_count": 211,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "complete(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "id": "c09579c6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'label': '104',\n",
       " 'label_desc': 'news_finance',\n",
       " 'sentence': '出栏一头猪亏损300元，究竟谁能笑到最后！',\n",
       " 'keywords': '商品猪,养猪,猪价,仔猪,饲料'}"
      ]
     },
     "execution_count": 115,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lines[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 220,
   "id": "54b4308e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "对给定文本进行分类，类别包括：科技、金融、娱乐、世界、汽车、文化、军事、旅游、游戏、教育、农业、房产、社会、股票。\n",
      "\n",
      "给定文本：\n",
      "出栏一头猪亏损300元，究竟谁能笑到最后！\n",
      "类别：\n",
      "\n"
     ]
    }
   ],
   "source": [
    "prompt = get_prompt(lines[2][\"sentence\"])\n",
    "print(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 221,
   "id": "322e0bba",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'农业'"
      ]
     },
     "execution_count": 221,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ask(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 222,
   "id": "4f27966f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'社会'"
      ]
     },
     "execution_count": 222,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "complete(prompt)"
   ]
  },
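  {
   "cell_type": "markdown",
   "id": "666931fa",
   "metadata": {},
   "source": [
    "This one is ambiguous: the dataset label is `news_finance`, while `gpt-3.5-turbo` answers 农业 (agriculture) and `text-davinci-003` answers 社会 (society). That said, agriculture seems like a reasonable category too. Note also that the category list in `get_prompt` still contains 股票 (stock), which was removed from the data, and omits 体育 (sports); it should be aligned with the dataset labels before any systematic evaluation. When the model does get things wrong, we have at least two ways to address it:\n",
    "\n",
    "- Few-Shot: each time, randomly sample a few examples from the dataset and include them in the prompt.\n",
    "- Fine-Tuning: prepare our own dataset in the required format and submit it to the API, which fine-tunes a model of our own that has learned from our data."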
   ]
  },
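  {
   "cell_type": "markdown",
   "id": "67c32ee7",
   "metadata": {},
   "source": [
    "The key to Few-Shot is how to choose the \"few\", that is, which cases to give the model as reference examples. For multi-class problems with many categories (hundreds or even thousands of labels are common in real work), even one example per label makes the context impractically long, so plain Few-Shot becomes inconvenient. If we still want to use it, the usual strategy applies: first retrieve a few sentences similar to the input, then use those sentences and their labels as the Few-Shot examples and ask the API to predict the category of the given sentence.\n",
    "\n",
    "Alternatively, we can fine-tune the model on our own dataset. Put simply, fine-tuning lets the model become \"familiar\" with our particular data, so that it can recognize the categories on similar data."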
   ]
  },
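  {
   "cell_type": "markdown",
   "id": "f3a91c07",
   "metadata": {},
   "source": [
    "Before moving on to fine-tuning, the retrieval-based Few-Shot strategy described above can be sketched as follows. This is only a minimal illustration under stated assumptions: `char_jaccard` and `build_few_shot_prompt` are hypothetical helpers that are not part of the original notebook, and character overlap stands in for a real similarity search (for example over sentence embeddings).\n",
    "\n",
    "```python\n",
    "def char_jaccard(a, b):\n",
    "    # Character-level Jaccard similarity; a crude stand-in for real retrieval.\n",
    "    sa, sb = set(a), set(b)\n",
    "    if not sa or not sb:\n",
    "        return 0.0\n",
    "    return len(sa & sb) / len(sa | sb)\n",
    "\n",
    "\n",
    "def build_few_shot_prompt(text, pool, k=3):\n",
    "    # pool: list of dicts with 'sentence' and 'label' keys (assumed layout).\n",
    "    ranked = sorted(pool, key=lambda ex: char_jaccard(text, ex['sentence']), reverse=True)\n",
    "    parts = ['对给定文本进行分类。']\n",
    "    for ex in ranked[:k]:\n",
    "        # Each retrieved example uses the same layout as the final query.\n",
    "        parts.append('给定文本：\\n' + ex['sentence'] + '\\n类别：\\n' + ex['label'])\n",
    "    parts.append('给定文本：\\n' + text + '\\n类别：\\n')\n",
    "    return '\\n\\n'.join(parts)\n",
    "```\n",
    "\n",
    "The resulting string could be passed to `ask` or `complete` in place of the zero-shot prompt. Only the `k` retrieved examples enter the context, so the prompt length stays bounded no matter how many labels exist."
   ]
  },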
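  {
   "cell_type": "markdown",
   "id": "d6e19030",
   "metadata": {},
   "source": [
    "Next, let's see how to fine-tune in practice. There are three main steps:\n",
    "\n",
    "- Prepare the data: format the data as the API requires. The data here is our own dataset, where each record contains at least a text and a category.\n",
    "- Fine-tune: pass the prepared data to the fine-tuning API, and the server completes the fine-tuning automatically. When it finishes, we get a new model_id. Note that this model_id belongs only to you; do not share it publicly.\n",
    "- Run inference with the new model: this is simple, just replace the value of the `model` parameter in the earlier call with our model_id.\n",
    "\n",
    "Now let's fine-tune this multi-class classifier. We only take the last 500 records as the training set (to keep it fast and cheap)."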
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3399cdb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_lines = lines[-500:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "094c9b49",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "5b5cd0a0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(500, 4)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train = pd.DataFrame(train_lines)\n",
    "train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "2e3e4e76",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>label_desc</th>\n",
       "      <th>sentence</th>\n",
       "      <th>keywords</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>103</td>\n",
       "      <td>news_sports</td>\n",
       "      <td>为什么斯凯奇与阿迪达斯脚感很相似，价格却差了近一倍？</td>\n",
       "      <td>达斯勒,阿迪达斯,FOAM,BOOST,斯凯奇</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>100</td>\n",
       "      <td>news_story</td>\n",
       "      <td>女儿日渐消瘦，父母发现有怪物，每天吃女儿一遍</td>\n",
       "      <td>大将军,怪物</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>104</td>\n",
       "      <td>news_finance</td>\n",
       "      <td>另类逼空确认反弹，剑指3200点以上</td>\n",
       "      <td>股票,另类逼空,金融,创业板,快速放大</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>100</td>\n",
       "      <td>news_story</td>\n",
       "      <td>老公在聚会上让我向他的上司敬酒，现在老公哭了，我笑了</td>\n",
       "      <td>远走高飞</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>108</td>\n",
       "      <td>news_edu</td>\n",
       "      <td>女孩上初中之后成绩下降，如何才能提升成绩？</td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  label    label_desc                    sentence                 keywords\n",
       "0   103   news_sports  为什么斯凯奇与阿迪达斯脚感很相似，价格却差了近一倍？  达斯勒,阿迪达斯,FOAM,BOOST,斯凯奇\n",
       "1   100    news_story      女儿日渐消瘦，父母发现有怪物，每天吃女儿一遍                   大将军,怪物\n",
       "2   104  news_finance          另类逼空确认反弹，剑指3200点以上      股票,另类逼空,金融,创业板,快速放大\n",
       "3   100    news_story  老公在聚会上让我向他的上司敬酒，现在老公哭了，我笑了                     远走高飞\n",
       "4   108      news_edu       女孩上初中之后成绩下降，如何才能提升成绩？                         "
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "aa55824e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "news_finance          48\n",
       "news_tech             47\n",
       "news_game             46\n",
       "news_entertainment    46\n",
       "news_travel           44\n",
       "news_sports           42\n",
       "news_military         40\n",
       "news_world            38\n",
       "news_car              36\n",
       "news_culture          35\n",
       "news_edu              27\n",
       "news_agriculture      20\n",
       "news_house            19\n",
       "news_story            12\n",
       "Name: label_desc, dtype: int64"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.label_desc.value_counts()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90c70f54",
   "metadata": {},
   "source": [
    "The `news_story` category has relatively few samples (12), but that is not a big problem."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff1b03fc",
   "metadata": {},
   "source": [
    "**Step 1: Prepare the data**\n",
    "\n",
    "The data must have two columns: `prompt` and `completion`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "8bd3a7c9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>prompt</th>\n",
       "      <th>completion</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>为什么斯凯奇与阿迪达斯脚感很相似，价格却差了近一倍？</td>\n",
       "      <td>news_sports</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>女儿日渐消瘦，父母发现有怪物，每天吃女儿一遍</td>\n",
       "      <td>news_story</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>另类逼空确认反弹，剑指3200点以上</td>\n",
       "      <td>news_finance</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>老公在聚会上让我向他的上司敬酒，现在老公哭了，我笑了</td>\n",
       "      <td>news_story</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>女孩上初中之后成绩下降，如何才能提升成绩？</td>\n",
       "      <td>news_edu</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                       prompt    completion\n",
       "0  为什么斯凯奇与阿迪达斯脚感很相似，价格却差了近一倍？   news_sports\n",
       "1      女儿日渐消瘦，父母发现有怪物，每天吃女儿一遍    news_story\n",
       "2          另类逼空确认反弹，剑指3200点以上  news_finance\n",
       "3  老公在聚会上让我向他的上司敬酒，现在老公哭了，我笑了    news_story\n",
       "4       女孩上初中之后成绩下降，如何才能提升成绩？      news_edu"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train = train[[\"sentence\", \"label_desc\"]]\n",
    "df_train.columns = [\"prompt\", \"completion\"]\n",
    "df_train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c935f5d",
   "metadata": {},
   "source": [
    "Save it as a JSONL file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "f1cbf196",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_train.to_json(\"dataset/tnews-finetuning.jsonl\", orient='records', lines=True)"
   ]
  },
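  {
   "cell_type": "markdown",
   "id": "b82d4e15",
   "metadata": {},
   "source": [
    "For reference, the file written above holds one JSON object per line with the keys `prompt` and `completion`. The same layout can be sketched with the standard library alone; `to_jsonl` is a hypothetical helper name, and `ensure_ascii=False` is an optional choice that keeps the Chinese text human-readable.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "\n",
    "def to_jsonl(records):\n",
    "    # records: list of dicts with 'sentence' and 'label_desc' keys, as in `lines`.\n",
    "    rows = []\n",
    "    for r in records:\n",
    "        row = {'prompt': r['sentence'], 'completion': r['label_desc']}\n",
    "        # One compact JSON object per line is what the JSONL format requires.\n",
    "        rows.append(json.dumps(row, ensure_ascii=False))\n",
    "    return '\\n'.join(rows) + '\\n'\n",
    "```\n",
    "\n",
    "Writing the returned string to a file opened with `encoding='utf-8'` gives a file in the same shape as the one produced by `to_json` above."
   ]
  },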
  {
   "cell_type": "markdown",
   "id": "8de6186e",
   "metadata": {},
   "source": [
    "Use the `openai` command-line tool to check and convert the file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "dae946e6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Analyzing...\n",
      "\n",
      "- Your file contains 500 prompt-completion pairs\n",
      "- Based on your data it seems like you're trying to fine-tune a model for classification\n",
      "- For classification, we recommend you try one of the faster and cheaper models, such as `ada`\n",
      "- For classification, you can estimate the expected model performance by keeping a held out dataset, which is not used for training\n",
      "- More than a third of your `prompt` column/key is uppercase. Uppercase prompts tends to perform worse than a mixture of case encountered in normal language. We recommend to lower case the data if that makes sense in your domain. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details\n",
      "- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples. If you intend to do open-ended generation, then you should leave the prompts empty\n",
      "- All completions start with prefix `news_`. Most of the time you should only add the output data into the completion, without any prefix\n",
      "- The completion should start with a whitespace character (` `). This tends to produce better results due to the tokenization we use. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details\n",
      "\n",
      "Based on the analysis we will perform the following actions:\n",
      "- [Recommended] Lowercase all your data in column/key `prompt` [Y/n]: Y\n",
      "- [Recommended] Add a suffix separator ` ->` to all prompts [Y/n]: Y\n",
      "- [Recommended] Remove prefix `news_` from all completions [Y/n]: Y\n",
      "- [Recommended] Add a whitespace character to the beginning of the completion [Y/n]: Y\n",
      "- [Recommended] Would you like to split into training and validation set? [Y/n]: Y\n",
      "\n",
      "\n",
      "Your data will be written to a new JSONL file. Proceed [Y/n]: Y\n",
      "\n",
      "Wrote modified files to `dataset/tnews-finetuning_prepared_train.jsonl` and `dataset/tnews-finetuning_prepared_valid.jsonl`\n",
      "Feel free to take a look!\n",
      "\n",
      "Now use that file when fine-tuning:\n",
      "> openai api fine_tunes.create -t \"dataset/tnews-finetuning_prepared_train.jsonl\" -v \"dataset/tnews-finetuning_prepared_valid.jsonl\" --compute_classification_metrics --classification_n_classes 14\n",
      "\n",
      "After you’ve fine-tuned a model, remember that your prompt has to end with the indicator string ` ->` for the model to start generating completions, rather than continuing with the prompt.\n",
      "Once your model starts training, it'll approximately take 14.33 minutes to train a `curie` model, and less for `ada` and `babbage`. Queue will approximately take half an hour per job ahead of you.\n"
     ]
    }
   ],
   "source": [
    "!openai tools fine_tunes.prepare_data -f dataset/tnews-finetuning.jsonl -q"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58a1acec",
   "metadata": {},
   "source": [
    "Let's see what the processed data looks like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "74d64c60",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"prompt\":\"cf生存特训：火箭弹狂野复仇，为兄弟报仇就要不死不休 ->\",\"completion\":\" game\"}\n",
      "{\"prompt\":\"哈尔滨 东北抗日联军博物馆 ->\",\"completion\":\" culture\"}\n",
      "{\"prompt\":\"中国股市中，庄家为何如此猖獗？一文告诉你真相 ->\",\"completion\":\" finance\"}\n",
      "{\"prompt\":\"天府锦绣又重来 ->\",\"completion\":\" agriculture\"}\n",
      "{\"prompt\":\"生活，游戏，电影中有哪些词汇稍加修改便可以成为一个非常霸气的名字？ ->\",\"completion\":\" game\"}\n",
      "{\"prompt\":\"法庭上，生父要争夺孩子抚养权，小男孩的发言让生父当场哑口无言 ->\",\"completion\":\" entertainment\"}\n",
      "{\"prompt\":\"如何才能选到好的深圳大数据培训机构？ ->\",\"completion\":\" edu\"}\n",
      "{\"prompt\":\"有哪些娱乐圈里面的明星追星？ ->\",\"completion\":\" entertainment\"}\n",
      "{\"prompt\":\"东坞原生态野生茶 ->\",\"completion\":\" culture\"}\n",
      "{\"prompt\":\"亚冠：恒大不胜早有预示，全北失利命中注定 ->\",\"completion\":\" sports\"}\n"
     ]
    }
   ],
   "source": [
    "!head dataset/tnews-finetuning_prepared_train.jsonl"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13caceb9",
   "metadata": {},
   "source": [
    "It is also a good idea to verify the number of categories in each split:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "ac7b23a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "train = pnlp.read_file_to_list_dict(\"./dataset/tnews-finetuning_prepared_train.jsonl\")\n",
    "valid = pnlp.read_file_to_list_dict(\"./dataset/tnews-finetuning_prepared_valid.jsonl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "b5383c59",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "14"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(Counter([v[\"completion\"] for v in train]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "83eaba49",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "14"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(Counter([v[\"completion\"] for v in valid]))"
   ]
  },
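  {
   "cell_type": "markdown",
   "id": "b1b3ea7e",
   "metadata": {},
   "source": [
    "OK. The most obvious change is that a marker (` ->`) was appended to every prompt. Besides that (see the log above), the tool also:\n",
    "\n",
    "- Lowercased the prompts\n",
    "- Removed the `news_` prefix from the labels\n",
    "- Added a space at the beginning of each completion\n",
    "- Split the data into training and validation sets\n",
    "\n",
    "These are all common, recommended preprocessing steps, so we keep them as they are."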
   ]
  },
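  {
   "cell_type": "markdown",
   "id": "c5e70a93",
   "metadata": {},
   "source": [
    "Because the fine-tuned model only sees prompts in this processed form, inputs at inference time presumably need the same treatment. A minimal sketch, assuming the ` ->` separator and the lowercasing shown in the log above; `format_for_inference` and `parse_completion` are hypothetical helper names, not part of the `openai` package.\n",
    "\n",
    "```python\n",
    "SEPARATOR = ' ->'  # suffix the CLI appended to every training prompt\n",
    "\n",
    "\n",
    "def format_for_inference(text):\n",
    "    # Mirror the prompt-side preprocessing: lowercase, then append the separator.\n",
    "    return text.lower() + SEPARATOR\n",
    "\n",
    "\n",
    "def parse_completion(raw):\n",
    "    # Completions were trained with a leading space and without the 'news_' prefix.\n",
    "    return 'news_' + raw.strip()\n",
    "```\n",
    "\n",
    "`format_for_inference(text)` would then be passed as `prompt` to `openai.Completion.create`, with `model` set to the fine-tuned model_id, and `parse_completion` maps the answer back to the original `label_desc` values."
   ]
  },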
  {
   "cell_type": "markdown",
   "id": "784fc33c",
   "metadata": {},
   "source": [
    "**Step 2: Fine-tune**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e9b6e04a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2da1ba12",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'sk-...'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "os.environ.setdefault(\"OPENAI_API_KEY\", \"your-own-api-key\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "f8e4e194",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Upload progress: 100%|████████████████████| 41.2k/41.2k [00:00<00:00, 15.6Mit/s]\n",
      "Uploaded file from ./dataset/tnews-finetuning_prepared_train.jsonl: file-DKjBKHqWFJwo7O8MNZGOcj3F\n",
      "Upload progress: 100%|████████████████████| 10.5k/10.5k [00:00<00:00, 7.85Mit/s]\n",
      "Uploaded file from ./dataset/tnews-finetuning_prepared_valid.jsonl: file-j088k3GWqGeqY0o2DDAfWfPh\n",
      "Created fine-tune: ft-QOkrWkHU0aleR6f5IQw1UpVL\n",
      "Streaming events until fine-tuning is complete...\n",
      "\n",
      "(Ctrl-C will interrupt the stream, but not cancel the fine-tune)\n",
      "[2023-04-04 21:16:33] Created fine-tune: ft-QOkrWkHU0aleR6f5IQw1UpVL\n",
      "\n",
      "Stream interrupted (client disconnected).\n",
      "To resume the stream, run:\n",
      "\n",
      "  openai api fine_tunes.follow -i ft-QOkrWkHU0aleR6f5IQw1UpVL\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!openai api fine_tunes.create \\\n",
    "    -t \"./dataset/tnews-finetuning_prepared_train.jsonl\" \\\n",
    "    -v \"./dataset/tnews-finetuning_prepared_valid.jsonl\" \\\n",
    "    --compute_classification_metrics --classification_n_classes 14 \\\n",
    "    -m davinci\\\n",
    "    --no_check_if_files_exist"
   ]
  },
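  {
   "cell_type": "markdown",
   "id": "49dcff4c",
   "metadata": {},
   "source": [
    "Here, `-t` and `-v` specify the training and validation sets, the next line computes classification metrics, and `-m` specifies the model to fine-tune. The last flag, `--no_check_if_files_exist`, skips the check for files that were already uploaded; without it, the tool checks whether the same file was uploaded before so that it can be reused. To keep the demo simple, we skip the check here. For the models that can be fine-tuned and their prices, see [Pricing](https://openai.com/pricing).\n",
    "\n",
    "Also worth mentioning: at the time of writing, only models served through the `Completion` endpoint can be fine-tuned, and `ChatCompletion` does not support fine-tuning. In other words, the GPT-3 base models (`ada`, `babbage`, `curie`, `davinci`) can be fine-tuned, but ChatGPT (`gpt-3.5-turbo`) cannot; see https://platform.openai.com/docs/guides/chat/is-fine-tuning-available-for-gpt-3-5-turbo\n",
    "\n",
    "We choose `davinci` here. Note that this is the base model, not the instruction-tuned `text-davinci-003` used in the examples above."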
   ]
  },
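  {
   "cell_type": "markdown",
   "id": "4a1d2885",
   "metadata": {},
   "source": [
    "As you can see, the stream above was interrupted shortly after it started. This is normal: interrupting the stream does not cancel the fine-tune, and we can check the job's progress with another command. Note that the ID here is the one printed in the log above, and it changes with every run."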
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "63f6cfcf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"created_at\": 1680614193,\n",
      "  \"events\": [\n",
      "    {\n",
      "      \"created_at\": 1680614193,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Created fine-tune: ft-QOkrWkHU0aleR6f5IQw1UpVL\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    }\n",
      "  ],\n",
      "  \"fine_tuned_model\": null,\n",
      "  \"hyperparams\": {\n",
      "    \"batch_size\": null,\n",
      "    \"classification_n_classes\": 14,\n",
      "    \"compute_classification_metrics\": true,\n",
      "    \"learning_rate_multiplier\": null,\n",
      "    \"n_epochs\": 4,\n",
      "    \"prompt_loss_weight\": 0.01\n",
      "  },\n",
      "  \"id\": \"ft-QOkrWkHU0aleR6f5IQw1UpVL\",\n",
      "  \"model\": \"davinci\",\n",
      "  \"object\": \"fine-tune\",\n",
      "  \"organization_id\": \"org-bKXddeZffpMS2CUNCCXsW7m5\",\n",
      "  \"result_files\": [],\n",
      "  \"status\": \"pending\",\n",
      "  \"training_files\": [\n",
      "    {\n",
      "      \"bytes\": 41212,\n",
      "      \"created_at\": 1680614191,\n",
      "      \"filename\": \"./dataset/tnews-finetuning_prepared_train.jsonl\",\n",
      "      \"id\": \"file-DKjBKHqWFJwo7O8MNZGOcj3F\",\n",
      "      \"object\": \"file\",\n",
      "      \"purpose\": \"fine-tune\",\n",
      "      \"status\": \"processed\",\n",
      "      \"status_details\": null\n",
      "    }\n",
      "  ],\n",
      "  \"updated_at\": 1680614193,\n",
      "  \"validation_files\": [\n",
      "    {\n",
      "      \"bytes\": 10507,\n",
      "      \"created_at\": 1680614193,\n",
      "      \"filename\": \"./dataset/tnews-finetuning_prepared_valid.jsonl\",\n",
      "      \"id\": \"file-j088k3GWqGeqY0o2DDAfWfPh\",\n",
      "      \"object\": \"file\",\n",
      "      \"purpose\": \"fine-tune\",\n",
      "      \"status\": \"processed\",\n",
      "      \"status_details\": null\n",
      "    }\n",
      "  ]\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "!openai api fine_tunes.get -i ft-QOkrWkHU0aleR6f5IQw1UpVL"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9839b8dd",
   "metadata": {},
   "source": [
    "Alternatively, use the command suggested in the log. Remember to replace the ID!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "id": "37fcbcf4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2023-04-04 14:32:14] Created fine-tune: ft-5LDv5IiFqPvLob3KkThWLTUG\n",
      "\n",
      "Stream interrupted (client disconnected).\n",
      "To resume the stream, run:\n",
      "\n",
      "  openai api fine_tunes.follow -i ft-5LDv5IiFqPvLob3KkThWLTUG\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!openai api fine_tunes.follow -i ft-QOkrWkHU0aleR6f5IQw1UpVL"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76513f44",
   "metadata": {},
   "source": [
    "Note that this one is `follow`, while the previous one was `get`. You can list more subcommands with `openai api --help`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "fd3da9a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "usage: openai api [-h]\n",
      "                  {engines.list,engines.get,engines.update,engines.generate,chat_completions.create,completions.create,deployments.list,deployments.get,deployments.delete,deployments.create,models.list,models.get,models.delete,files.create,files.get,files.delete,files.list,fine_tunes.list,fine_tunes.create,fine_tunes.get,fine_tunes.results,fine_tunes.events,fine_tunes.follow,fine_tunes.cancel,fine_tunes.delete,image.create,image.create_edit,image.create_variation,audio.transcribe,audio.translate}\n",
      "                  ...\n",
      "\n",
      "positional arguments:\n",
      "  {engines.list,engines.get,engines.update,engines.generate,chat_completions.create,completions.create,deployments.list,deployments.get,deployments.delete,deployments.create,models.list,models.get,models.delete,files.create,files.get,files.delete,files.list,fine_tunes.list,fine_tunes.create,fine_tunes.get,fine_tunes.results,fine_tunes.events,fine_tunes.follow,fine_tunes.cancel,fine_tunes.delete,image.create,image.create_edit,image.create_variation,audio.transcribe,audio.translate}\n",
      "                        All API subcommands\n",
      "\n",
      "optional arguments:\n",
      "  -h, --help            show this help message and exit\n"
     ]
    }
   ],
   "source": [
    "!openai api --help"
   ]
  },
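  {
   "cell_type": "markdown",
   "id": "f55e5517",
   "metadata": {},
   "source": [
    "It is enough to `get` the progress every once in a while; there is no need to `follow` continuously. You may have to wait for some time here, but once the job leaves the queue, the training itself is quick.\n",
    "\n",
    "Mainly look at the `status` field. Again, remember to replace the ID."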
   ]
  },
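  {
   "cell_type": "markdown",
   "id": "d9046b2e",
   "metadata": {},
   "source": [
    "The \"check every once in a while\" advice can be sketched as a small polling loop. The status fetcher is passed in as a function, so nothing here depends on a particular client call. The helper name `wait_for_fine_tune` is hypothetical, and the set of terminal status strings is an assumption that should be checked against what `fine_tunes.get` actually returns.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "# Assumed terminal states of a fine-tune job.\n",
    "TERMINAL_STATUSES = {'succeeded', 'failed', 'cancelled'}\n",
    "\n",
    "\n",
    "def wait_for_fine_tune(fetch_status, interval=60.0, max_checks=120, sleep=time.sleep):\n",
    "    # fetch_status: zero-argument callable returning the current status string.\n",
    "    status = None\n",
    "    for _ in range(max_checks):\n",
    "        status = fetch_status()\n",
    "        if status in TERMINAL_STATUSES:\n",
    "            break\n",
    "        # Not finished yet: wait before asking again.\n",
    "        sleep(interval)\n",
    "    return status\n",
    "```\n",
    "\n",
    "With the pre-1.0 `openai` Python package used in this notebook, `fetch_status` could be something like `lambda: openai.FineTune.retrieve(id='ft-...')['status']`, using your own job ID."
   ]
  },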
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "71f561b9",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"created_at\": 1680614193,\n",
      "  \"events\": [\n",
      "    {\n",
      "      \"created_at\": 1680614193,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Created fine-tune: ft-QOkrWkHU0aleR6f5IQw1UpVL\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680614845,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune costs $2.33\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680614846,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune enqueued\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680617657,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 31\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680617785,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 30\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680617805,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 29\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680617809,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 28\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680617918,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 27\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680617928,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 26\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618038,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 25\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618050,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 24\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618087,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 23\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618096,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 22\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618207,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 21\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618256,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 20\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618268,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 19\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618336,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 18\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618372,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 17\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618445,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 16\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618488,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 15\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618582,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 14\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618612,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 13\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618652,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 12\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618693,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 11\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618717,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 10\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618759,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 9\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618790,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 8\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618841,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 7\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618886,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 6\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618899,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 5\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618928,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 4\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680618979,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 3\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680619057,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 2\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680619065,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 1\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680619127,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune is in the queue. Queue number: 0\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680619227,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune started\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680619443,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Completed epoch 1/4\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680619573,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Completed epoch 2/4\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680619701,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Completed epoch 3/4\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680619829,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Completed epoch 4/4\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680619890,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Uploaded model: davinci:ft-personal-2023-04-04-14-51-29\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680619891,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Uploaded result file: file-xvIjsDl5aFEtXWOVY6nZyLcD\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    },\n",
      "    {\n",
      "      \"created_at\": 1680619891,\n",
      "      \"level\": \"info\",\n",
      "      \"message\": \"Fine-tune succeeded\",\n",
      "      \"object\": \"fine-tune-event\"\n",
      "    }\n",
      "  ],\n",
      "  \"fine_tuned_model\": \"davinci:ft-personal-2023-04-04-14-51-29\",\n",
      "  \"hyperparams\": {\n",
      "    \"batch_size\": 1,\n",
      "    \"classification_n_classes\": 14,\n",
      "    \"compute_classification_metrics\": true,\n",
      "    \"learning_rate_multiplier\": 0.1,\n",
      "    \"n_epochs\": 4,\n",
      "    \"prompt_loss_weight\": 0.01\n",
      "  },\n",
      "  \"id\": \"ft-QOkrWkHU0aleR6f5IQw1UpVL\",\n",
      "  \"model\": \"davinci\",\n",
      "  \"object\": \"fine-tune\",\n",
      "  \"organization_id\": \"org-bKXddeZffpMS2CUNCCXsW7m5\",\n",
      "  \"result_files\": [\n",
      "    {\n",
      "      \"bytes\": 82416,\n",
      "      \"created_at\": 1680619891,\n",
      "      \"filename\": \"compiled_results.csv\",\n",
      "      \"id\": \"file-xvIjsDl5aFEtXWOVY6nZyLcD\",\n",
      "      \"object\": \"file\",\n",
      "      \"purpose\": \"fine-tune-results\",\n",
      "      \"status\": \"processed\",\n",
      "      \"status_details\": null\n",
      "    }\n",
      "  ],\n",
      "  \"status\": \"succeeded\",\n",
      "  \"training_files\": [\n",
      "    {\n",
      "      \"bytes\": 41212,\n",
      "      \"created_at\": 1680614191,\n",
      "      \"filename\": \"./dataset/tnews-finetuning_prepared_train.jsonl\",\n",
      "      \"id\": \"file-DKjBKHqWFJwo7O8MNZGOcj3F\",\n",
      "      \"object\": \"file\",\n",
      "      \"purpose\": \"fine-tune\",\n",
      "      \"status\": \"processed\",\n",
      "      \"status_details\": null\n",
      "    }\n",
      "  ],\n",
      "  \"updated_at\": 1680619892,\n",
      "  \"validation_files\": [\n",
      "    {\n",
      "      \"bytes\": 10507,\n",
      "      \"created_at\": 1680614193,\n",
      "      \"filename\": \"./dataset/tnews-finetuning_prepared_valid.jsonl\",\n",
      "      \"id\": \"file-j088k3GWqGeqY0o2DDAfWfPh\",\n",
      "      \"object\": \"file\",\n",
      "      \"purpose\": \"fine-tune\",\n",
      "      \"status\": \"processed\",\n",
      "      \"status_details\": null\n",
      "    }\n",
      "  ]\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "!openai api fine_tunes.get -i ft-QOkrWkHU0aleR6f5IQw1UpVL"
   ]
  },
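  {
   "cell_type": "markdown",
   "id": "c4a7e0b1",
   "metadata": {},
   "source": [
    "&emsp;&emsp;The JSON above is long, but only a few fields matter when checking progress. The cell below is a minimal sketch, assuming the response has already been parsed into a dict. The helper name `summarize_fine_tune` is our own and is not part of the OpenAI CLI or library. It pulls out `id`, `status`, `fine_tuned_model`, and the message of the most recent event."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4a7e0b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def summarize_fine_tune(job):\n",
    "    \"\"\"Pick out the fields worth checking from a parsed `fine_tunes.get` response.\n",
    "\n",
    "    `job` is the response parsed into a dict. Hypothetical helper, not an OpenAI API.\n",
    "    \"\"\"\n",
    "    events = job.get('events') or []\n",
    "    last_event = events[-1]['message'] if events else None\n",
    "    return {\n",
    "        'id': job.get('id'),\n",
    "        'status': job.get('status'),\n",
    "        'fine_tuned_model': job.get('fine_tuned_model'),\n",
    "        'last_event': last_event,\n",
    "    }\n",
    "\n",
    "\n",
    "# A trimmed copy of the response shown above (most events omitted)\n",
    "sample_job = {\n",
    "    'id': 'ft-QOkrWkHU0aleR6f5IQw1UpVL',\n",
    "    'status': 'succeeded',\n",
    "    'fine_tuned_model': 'davinci:ft-personal-2023-04-04-14-51-29',\n",
    "    'events': [\n",
    "        {'message': 'Fine-tune started'},\n",
    "        {'message': 'Fine-tune succeeded'},\n",
    "    ],\n",
    "}\n",
    "summarize_fine_tune(sample_job)"
   ]
  },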
  {
   "cell_type": "markdown",
   "id": "7095e389",
   "metadata": {},
   "source": [
    "&emsp;&emsp;After a long wait, the job finally succeeded. Once fine-tuning has finished, we can also view the results with the following command:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5059e9dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -i takes the ID of the fine-tune job above, i.e. the `id` field in the response\n",
    "!openai api fine_tunes.results -i ft-QOkrWkHU0aleR6f5IQw1UpVL > metric.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "a09093a5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>step</th>\n",
       "      <th>elapsed_tokens</th>\n",
       "      <th>elapsed_examples</th>\n",
       "      <th>training_loss</th>\n",
       "      <th>training_sequence_accuracy</th>\n",
       "      <th>training_token_accuracy</th>\n",
       "      <th>validation_loss</th>\n",
       "      <th>validation_sequence_accuracy</th>\n",
       "      <th>validation_token_accuracy</th>\n",
       "      <th>classification/accuracy</th>\n",
       "      <th>classification/weighted_f1_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1601</th>\n",
       "      <td>1602</td>\n",
       "      <td>83226</td>\n",
       "      <td>1602</td>\n",
       "      <td>0.008739</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.63</td>\n",
       "      <td>0.619592</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      step  elapsed_tokens  elapsed_examples  training_loss  \\\n",
       "1601  1602           83226              1602       0.008739   \n",
       "\n",
       "      training_sequence_accuracy  training_token_accuracy  validation_loss  \\\n",
       "1601                         1.0                      1.0              NaN   \n",
       "\n",
       "      validation_sequence_accuracy  validation_token_accuracy  \\\n",
       "1601                           NaN                        NaN   \n",
       "\n",
       "      classification/accuracy  classification/weighted_f1_score  \n",
       "1601                     0.63                          0.619592  "
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "metric = pd.read_csv('metric.csv')\n",
    "metric[metric['classification/accuracy'].notnull()].tail(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "58f97545",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<AxesSubplot:>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD2CAYAAAA6eVf+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAgXElEQVR4nO3deXxV9Z3/8dcnCUnYEgMkgKwqAooYLJG64CBYUato1VqxtVo7HdpOO6JItXbs2HZmWltX1NZfafsb+7NWxIWxxUFl0aKjiKGQCEhkFYIsCZSELSHL5/dHLjSEe5MLWc5d3s/HIw/PPfd7rp+v5r7Pud/zvd+YuyMiIskhJegCRESk4yj0RUSSiEJfRCSJKPRFRJKIQl9EJImkBV1Ac3r16uWDBw8OugwRkbiybNmycnfPDfdcTIf+4MGDKSwsDLoMEZG4YmafRHpOwzsiIklEoS8ikkQU+iIiSUShLyKSRBT6IiJJRKEvIpJEFPoiIkkkpufpi0j72lS+n3krt5PduRN9szPpk51J3+xMsjt3wsyCLk/agUJfJAmV76vmiYVrefb9zdTWH/s3NTLSUhqdBDrTOyvzqJNCn6xMenbLIDVFJ4Z4o9AXSSL7q2v57dsbmbl4PVW19Uw+dwDfGT8EB7ZXVLG9ooptFQcbtisbHi/duJsdlVXHnBzSUozeWQ0ngj5ZjU4IjR7ndc8kPU2jyLFEoS+SBGrq6pn1wRZmLFhL+b5qrjirD9MvG8Zpud2OtOl3UueIx9fXO7v2HzpyUthRWcW2IyeJKj7aVsmiNTs5WFN31HFm0KtbxlEnhaM/NXSmT1YmndNT263vcjSFvkgCc3fmrdzOg6+XsLF8P2MG92DmLaP5zMCc43qdlBQjt3sGud0zGNk/O+K/q/JgLdsr//5pYVtF1ZETxOZdB3h/wy4qq2qPObbxPYWjPzV0PvI4KzNN9xnagEJfJEG9v2EXP5u3hhVb9jC0dzd+d2sBE4bntVtwmhnZXTqR3aUTw/p0j9juwKHaRkNJfx9GOnyCWLm1kvJ91ccc1yU9NeJJ4fAJo0eXdFJ0n6FZCn2RBFOyfS+/eG0NC9fspE9WJr+4/myuH90/Zm66dklP49TcbpzaaGipqUO19eyorDpmGGlH6FPEkvW72LG3mrom9xnSU1PIy8podELIoE9256M+ReR1zyAtNXnvMyj0RRLEp3sO8uj8j3npr6V0zUjjnsuHc9uFg8nsFH/j5elpKQzo0YUBPbpEbFNX7+zaV822JsNI2ysOsq2iig9L9/BGRRXVtfVHHZdikNs948hJoW9252PuN/TOyozL/27RUOiLxLmKAzX86i/rePp/N+EO/zj2FP754iHkdE0PurR2lZpi5GVlkpeVSf6A8G3cnT0Hao4aQtpecTB036GKDWX7eXfdLvZWH3ufoUfX9KNvOmdl0jt0cjj8SaJbRvxFaPxVLCIAVNXU8f/e28Qv31xPZVUN147qx7SJQ+mfE/nqONmYGTld08npms4ZfbMitttX/ff7DA0niINHhpW2V1ZRtGUPu/YfOua4bhlpR313oc9R32Vo+ASR0yW2vuim0BeJM3X1zn8v38oj8z9m656DjBuayz2XD+fMkyOHmjSvW0YaQ/K6MSQv8n2Gqpo6dlZWHzU7qfEniLU7ytm5t4qm33VLP/xFt8YnhSPbDfcbenXgF90U+iJxwt156+Myfj5vDWu272Vkv2we/OLZXDCkV9ClJYXMTqkM7NmFgT0jf5KqraunfN+hsCeF7ZVVLN+8h+0VVRyqO/o+Q2qKkdc946hPCeOG5TJuaNg/c9sqCn2ROFC0ZQ8PzFvDext2MbBHF5646RyuHNlX0xNjTFpqypGr+Ujcnd37DzW5z/D3E0TJ9r38paSMbplpCn2RZLOpfD8PvlHCq8Xb6Nk1nR9fPYKbxgzU0gZxzMzo2S2Dnt0yGHFy+C+6AcdMR20rCn2RGFS+r5rH
F67lj+9vplNqCrdPGMI//cOpdM/sFHRp0kHaa4y/xdA3s0zgRWAAUAzc4u7HnILM7G5gErAPuAZIBWYDfYDl7j7FzL4FTAe2hw670t0r2qIjIokg3IJoUz93OnndIw8XiByPaK70bwZK3f0qM5sLXAq80biBmZ0KjHD3i8zsdqA/cAFQ7O6TzGyjmZ0Zan6/uz/bhn0QiXvRLIgm0haiCf0JwEuh7UXAeJqEPnAJkGNmi4EdwBNADlAYer7xBNfvmtl0YLG7Tz3RwkUSQVstiCYSrWhCvydweAimEhgWpk0uUObuV5vZe8BYd38bwMzuAJa5+2oz6wp8CCwBNprZo+6+qfELmdkUYArAwIEDj79HInGioxdEE4HoQr8cOHyLOTv0uKlKoCS0vQHoB2Bm3wYuAiaHntsMlLt7nZmVAnnApsYv5O4zgZkABQUF7XP7WiRAJdv38vPX1rAoRhdEk8QWTegvBCbSMMQzAXg0TJtlwLTQ9hBgg5nlA1cC17j74b+s8Agw08yWAgOBta2oXSSuJNKCaBK/ogn9Z4HrzKwYKALWm9lD7j79cAN3f8/Mys3sA+Ajd19qZk8Bg4G/hD6u/ifwU+C3QDrwE3f/W9t2RyT2JOuCaBKbLMzsy5hRUFDghYWFLTcUiUFaEE2CYmbL3L0g3HP6cpZIG9OCaBLLFPoibUQLokk8UOiLtIGiLXv42byPWLJhtxZEk5im0BdpBS2IJvFGoS9yArQgmsQrhb7IcdhfXctv3t7AbxZv0IJoEpcU+iJR0IJokigU+iLN0IJokmgU+iIRaEE0SUQKfZEmtCCaJDKFvkjIp3sO8khoQbRuWhBNEpRCX5Je0wXRvqEF0SSBKfQlaWlBNElGCn1JOnX1zpzlW3nkjRI+rajSgmiSVBT6kjTCLYj20A35WhBNkopCX5KCFkQTaaDQl4S2qXw/D75ewqsfakE0EVDoS4LSgmgi4Sn0JaFoQTSR5in0JSHU1NUza+lmZixcS/m+Q1oQTSSCFkPfzDKBF4EBQDFwi4f5a+pmdjcwCdgHXANkAXOAk4BX3f37Ztar6b626YYkq/ALog3XgmgiEURzN+tmoNTd84Ec4NKmDczsVGCEu18EzAP6A3cArwL5wBVmNjTCPpETsmTDLr7wq3f552f/SqdU43e3FvD8N89T4Is0I5rhnQnAS6HtRcB44I0mbS4BcsxsMbADeCJ03L+4e72Z/SV0XLh9Hzd+ITObAkwBGDhw4Al1ShKbFkQTOXHRhH5PoCK0XQkMC9MmFyhz96vN7D1gbJjjekTYdxR3nwnMBCgoKDhmGEmSlxZEE2m9aEK/HMgObWeHHjdVCZSEtjcA/cIc90mEfSLNqjhQw6/eWsd/vbsJtCCaSKtEE/oLgYk0DPFMAB4N02YZMC20PYSG4F8ITDSz5cA4YAYNN4Ob7hMJSwuiibS9aEL/WeA6MysGioD1ZvaQu08/3MDd3zOzcjP7APjI3Zea2QYaZup8Bfizu68zs8eb7mvzHknc04JoIu3Hwsy+jBkFBQVeWFgYdBnSQdydt0rK+Plrf18Q7d4rhmtBNJHjZGbL3L0g3HP6cpbEBC2IJtIxFPoSKC2IJtKxFPoSiLK91TyxSAuiiXQ0hb50KC2IJhIshb50mNdWbuO+/15F+b5qLYgmEhCFvnSIj3fsZeqsFZzeuxszbxmt9XFEAqLQl3ZXXVvH7c8tp1tGGv/1tTHkds8IuiSRpKXQl3b34GslrNm+l//7tQIFvkjANC9O2tXij8v47TsbueX8QUwY3jvockSSnkJf2s3u/Ye464UiTs/rxg8+f0bQ5YgIGt6RduLu3PNSMRUHavj9bWO0/LFIjNCVvrSL55ZuYf7qHdx9+TAtlCYSQxT60ubW7dzHT+au4qLTe/H1C08JuhwRaUShL23qUG09U2ctp3OnVB66IV8LponEGI3pS5t6eH4Jqz6tZOZXR9M7S0sriMQaXelLm3l3XTkzF2/gy58d
yMQRfYIuR0TCUOhLm9hz4BDTZhdxSq+u3HelpmeKxCqFvrSau3Pvyx+ya381j08+hy7pGjUUiVUKfWm1FwpLmbdyO3dNHMZZ/bKDLkdEmtFs6JtZppnNNbMiM3vGzI6ZimFml5tZqZm9E/oZZmYXN3q8xcxuDdeu/bolHWVj+X5+9OdVXHBaT6ZcdGrQ5YhIC1q60r8ZKHX3fCAHuDRCu6fcfWzop8Td3zr8GCgGlodr1zZdkKDU1NVzx6zldEpN4eEvaXqmSDxoKfQnAPND24uA8RHaXW9mS83spcafBsysCzDE3Yubayfx6bEFH1NUWsED142kb3bnoMsRkSi0FPo9gYrQdiXQI0yb9cAP3X0M0BcY1+i5S4GFUbQ7wsymmFmhmRWWlZVF1wvpcO9v2MWv3lrPlwr6c8XIvkGXIyJRain0y4HDd+ayQ4+b2g0sCG1vAvIaPTcJmBtFuyPcfaa7F7h7QW5ubgvlSRAqDtZw5/MrGNSjC/dPGhF0OSJyHFoK/YXAxND2BODNMG2mAZPNLAU4C1gJEBq+GU/DsFDEdhJf3J1/nfMhO/dWM2PyOXTN0PRMkXjSUug/C/Qzs2IartTXm9lDTdo8CdwGvA/McffVof3nAqvcvaqFdhJHXv7rVuYWb+POS4eSP+CkoMsRkeNk7h50DREVFBR4YWFh0GVIyCe79vP5GW8z4uRsnptyHqmarSMSk8xsmbsXhHtOX86SqNTW1XPH8ytISTEenTxKgS8SpzQgK1F5YtE6lm/ewxM3nUO/kzQ9UyRe6UpfWrTsk908sWgt132mH5PyTw66HBFpBYW+NKuyqoaps1bQL6czP75a0zNF4p2Gd6RZ97+yim0VVcz+5vl0z+wUdDki0kq60peIXlmxlTnLt3L7hNMZPSgn6HJEpA0o9CWsLbsPcN+clYwelMN3xp8WdDki0kYU+nKM2rp6ps1egQOP3TiKtFT9mogkCo3pyzGeems9H2z6G4/emM+AHl2CLkdE2pAu4eQoyzf/jccWruWaUSdz7Tn9gy5HRNqYQl+O2Fddyx3Pr6BPViY/ueasoMsRkXag4R054kd/WsWW3QeYNeV8sjtreqZIItKVvgAwt/hTXlxWynfGD2HMKeH+Vo6IJAKFvvDpnoP84OUPGTXgJG6/5PSgyxGRdqTQT3J19c6dz6+grt6ZMXkUnTQ9UyShaUw/yc1cvIH3N+7mwS+ezaCeXYMuR0TamS7rklhx6R4efqOEK0f25YujNT1TJBko9JPUgUO1TJ21gtzuGfz02pE0/EljEUl0Gt5JUv8+dzWbdu3nj984j+wump4pkiyavdI3s0wzm2tmRWb2jIW5HDSzy82s1MzeCf0Mi7CvxdeSjvHayu08t3QL3xp3Guef1jPockSkA7U0vHMzUOru+UAOcGmEdk+5+9jQT0mEfdG+lrSj7RVVfP/lYkb2y+bOzw0NuhwR6WAthf4EYH5oexEwPkK7681sqZm91OgKvum+aF9L2kl9vXPXCyuorqnnscmjSE/TLR2RZNPSu74nUBHargTCfVVzPfBDdx8D9AXGRdgXzWthZlPMrNDMCsvKyo6nL9KC372zkf9dt4v7J53Jabndgi5HRALQUuiXA9mh7ezQ46Z2AwtC25uAvAj7onkt3H2muxe4e0Fubm7LPZCorPq0gl+8vobLRvTmxnMHBF2OiASkpdBfCEwMbU8A3gzTZhow2cxSgLOAlRH2RfNa0g4OHqpj6qwV9OiazgPXna3pmSJJrKXQfxboZ2bFNFy9rzezh5q0eRK4DXgfmOPuqyPsa/paC9uuG9Kc//yf1azbuY+HbxhFTtf0oMsRkQA1O0/f3auBq5rsnt6kzTbg4ij2hXstaWcLVu/gD0s2808XncLY03sFXY6IBEzTNxLYzr1V3P1SMWf2zWL6ZcOCLkdEYoBCP0HV1zvTXyhmf3Utj980ioy01KBLEpEYoNBPUE+/u4nFH5dx31VnMiSve9DliEiM
UOgnoI+2VfLAvDV87ow8bv7swKDLEZEYotBPMFU1ddwxawVZnTvx8+s1PVNEjqZVNhPMA/PWULJjL0/fdi49u2UEXY6IxBhd6SeQN0t28vS7m7jtwsFcPCwv6HJEJAYp9BNE+b5qvvdCEcP7dOeey4cHXY6IxCgN7yQAd+fuF4uprKrl2W+cR2YnTc8UkfB0pZ8A/rDkExat2ckPrhjOsD6anikikSn049zaHXv5j1c/YtzQXG69YHDQ5YhIjFPox7Hq2jpun7WCbhlpPHiDpmeKSMs0ph/HHnythI+2VfK7WwvI654ZdDkiEgd0pR+n3l5bxm/f2cgt5w/ikjN6B12OiMQJhX4c2r3/EHfNLmJIXjd+8Pkzgi5HROKIQj/OuDv3vFTMngM1zJg8StMzReS4KPTjzHNLtzB/9Q7uvnwYI07ObvkAEZFGFPpxZN3Offxk7iouOr0XX7/wlKDLEZE4pNCPE4dq67nj+eV07pTKQzfkk5Ki6Zkicvw0ZTNOPDy/hJVbK/n1V0fTO0vTM0XkxDR7pW9mmWY218yKzOwZC/PtHzO73MxKzeyd0M8wa/B7M1tiZn8ys7Rw7dqvW4nl3XXlzFy8gZvGDOSyEX2CLkdE4lhLwzs3A6Xung/kAJdGaPeUu48N/ZQAFwJp7n4ekAVMjNBOWrDnwCGmzS7ilF5d+eFVmp4pIq3TUuhPAOaHthcB4yO0u97MlprZS6FPAzuAGaHnDjXTTprh7tz78ofs2l/N45PPoUu6RuNEpHVaCv2eQEVouxLoEabNeuCH7j4G6AuMc/e17r7UzK4F0oHXw7UL9y80sylmVmhmhWVlZcffowTywrJS5q3czl0Th3FWP03PFJHWayn0y4HDaZMdetzUbmBBaHsTkAdgZlcDU4FJ7l4XqV1T7j7T3QvcvSA3Nze6XiSgjeX7+dGfVnH+qT2ZctGpQZcjIgmipdBfyN/H4ycAb4ZpMw2YbGYpwFnASjPrA3wPuNLd90Zq19riE1VNXT13zFpOp9QUHrlR0zNFpO20FPrPAv3MrJiGK/X1ZvZQkzZPArcB7wNz3H01cCsNQzivh2bqfD1COwljxoK1FJVW8MB1I+mb3TnockQkgZi7B11DRAUFBV5YWBh0GR3q/Q27mPybJdwwuj+/+GJ+0OWISBwys2XuXhDuOX0jN4ZUHKxh2uwiBvXowv2TRgRdjogkIM0BjBHuzr/O+ZAdlVW8+O0L6Jqh/zUi0vZ0pR8j5izfytzibdx56VBGDTgp6HJEJEEp9GPA5l0H+LdXVjFmcA++Ne60oMsRkQSm0A9YbV3D6plm8MiN+aRqeqaItCMNHAfsiUXr+OvmPTx+0zn0z+kSdDkikuB0pR+gZZ/s5olFa7nuM/24Ov/koMsRkSSg0A/I3qoaps5aQb+czvz4ak3PFJGOoeGdgNz/yiq2VVQx+5vn0z2zU9DliEiS0JV+AF5ZsZWXl2/lXyYMYfSgnKDLEZEkotDvYFt2H+C+OSsZPSiH744fEnQ5IpJkFPodqK7emTZ7BQ48duMo0lL1n19EOpbG9DvQU2+t44NNf+PRG/MZ0EPTM0Wk4+lSs4Ms3/w3Hl2wlqvzT+YLo/oFXY6IJCmFfgfYV13LHc+voE9WJv/+hbPQnwcWkaBoeKcD/PhPq9iy+wCzppxPdmdNzxSR4OhKv529WryNF5aV8p3xQxhzSri/Ky8i0nEU+u3o0z0HufflYvIHnMTtl5wedDkiIgr99nJ4emZdvTPjxlF00vRMEYkBzSaRmWWa2VwzKzKzZyzMHUgzu9zMSkN/AP0dMxsW7rhoXiuRzFy8gSUbdvOjq0cwuFfXoMsREQFavtK/GSh193wgB7g0Qrun3H1s6KckwnHRvlbc+7C0goffKOHKkX354uj+QZcjInJES6E/AZgf2l4EjI/Q7nozW2pmL4Wu4MMdF+1rxbUDh2qZOms5ud0z+M9rNT1TRGJLS6HfE6gIbVcC4aafrAd+6O5j
gL7AuAjHRfNamNkUMys0s8KysrJo+xEz/n3uajbu2s/DX8rnpC7pQZcjInKUlkK/HMgObWeHHje1G1gQ2t4E5EU4LprXwt1nunuBuxfk5uZG0YXY8drK7Ty3dAvf/IfTuOC0XkGXIyJyjJZCfyEwMbQ9AXgzTJtpwGQzSwHOAlZGOC6a14pbOyqr+P7LxYzsl820S4cGXY6ISFgthf6zQD8zK6bhin69mT3UpM2TwG3A+8Acd18d5riFEfYlhPp6567ZRVTX1PPY5FGkp2l6pojEpmaXYXD3auCqJrunN2mzDbg4iuPC7UsIv3tnI++sK+dn143ktNxuQZcjIhKRLklbadWnFfzi9TVcNqI3k88dEHQ5IiLNUui3wsFDdUydtYIeXdN54LqzNT1TRGKeVtlshZ/+z0es27mPP/zjZ8npqumZIhL7dKV/ghZ+tINnlnzCP110CmNP1/RMEYkPCv0TsHNvFd97sZgz+2Yx/bJhQZcjIhI1hf5xqq93pr9QzP7qWh6/aRQZaalBlyQiEjWF/nH6/XubWPxxGfdddSZD8roHXY6IyHFR6B+HNdsr+dm8NVwyPI+bPzsw6HJERI6bQj9KVTV1TH1uBVmZnfj5FzU9U0Tik6ZsRumBeWso2bGXp287l17dMoIuR0TkhOhKPwpvluzk6Xc3cduFg7l4WF7Q5YiInDCFfgvK91XzvReKGda7O/dcPjzockREWkXDO81wd+55sZjKqhr+8I0xZHbS9EwRiW+60m/GH5Z8wsI1O7n3iuEM75MVdDkiIq2m0I9g7Y69/MerHzFuaC5fu2Bw0OWIiLQJhX4Y1bV13D5rBd0y0njwBk3PFJHEoTH9MB56vYSPtlXyu1sLyOueGXQ5IiJtRlf6Tby9tozfvL2Rr543iEvO6B10OSIibUqh38ju/Ye4a3YRQ/K68a9XnhF0OSIibU6hH+Lu3PNSMXsO1DBj8ihNzxSRhNRi6JtZppnNNbMiM3vGmrmraWZ3mtmC0PZkM3sn9LPTzMaZ2bfMbF2j/dlt2ZnWmPXBFuav3sHdlw9jxMkxU5aISJuK5kr/ZqDU3fOBHODScI3MbBDwtcOP3X2Wu49197HANqA49NT9h/e7e0Wrqm8j68v28ZM/r2bskF58/cJTgi5HRKTdRBP6E4D5oe1FwPgI7WYA9zbdaWanAnvc/W+hXd81s+VmNiPci5jZFDMrNLPCsrKyKMprnUO19UydtZzMTik8/KV8UlI0PVNEElc0od8TOHxFXgn0aNrAzL4MFAGrwxw/CXg1tL0MmA4UANea2eCmjd19prsXuHtBbm5uFOW1ziPzP2bl1koeuP5semdpeqaIJLZoQr8cODzInR163NRVwCXALGC0mX230XOTgLmh7c3AEnevA0qBQJesfHd9Ob9evJ6bxgzkshF9gixFRKRDRBP6C4GJoe0JwJtNG7j7l0Nj95OBZe7+JICZZQH93f3wJ4BHgLFm1hkYCKxtZf0nbM+BQ0x7vohTenblh1dpeqaIJIdoQv9ZoJ+ZFQO7gfVm9lCUr3858Eajxz8FHgDeAX7SaJy/Q7k7P5jzIbv2VzNj8jl0SdcXk0UkObSYdu5eTcPwTWPTI7TdBHyu0ePZwOxGj1cB559IoW3phWWl/M+H2/n+FcMZ2V/TM0UkeSTdl7M2le/nR39axfmn9mTKRacGXY6ISIdKqtCvqatn6vMr6JSq6ZkikpySajB7xoK1FG3Zw6++8hlOPqlz0OWIiHS4pLnSf3/DLn751jq+VNCfz4/sG3Q5IiKBSIrQrzhYw7TZRQzq0YX7J40IuhwRkcAk/PCOu3Pff69ke2UVL337ArpmJHyXRUQiSvgr/TnLt/Lnok+583OnM2rASUGXIyISqIQO/c27DvBvr6xizOAefPviIUGXIyISuIQN/dq6eu54fjlm8MiN+aRqeqaISOKO6T/55jr+unkPj990Dv1zugRdjohITEjIK/1ln+zm8YVrue6cflyd
f3LQ5YiIxIyEDP301FQuHNKLH1+j6ZkiIo0l5PDOyP7ZPPOPnw26DBGRmJOQV/oiIhKeQl9EJIko9EVEkohCX0QkiSj0RUSSiEJfRCSJKPRFRJKIQl9EJImYuwddQ0RmVgZ8coKH9wLK27CcIKkvsSdR+gHqS6xqTV8GuXtuuCdiOvRbw8wK3b0g6DragvoSexKlH6C+xKr26ouGd0REkohCX0QkiSRy6M8MuoA2pL7EnkTpB6gvsapd+pKwY/oiInKsRL7SFxGRJhT6IiJJJCFC38zuNLMFZtbLzN42sw/N7IHQc8fsi1Vmdneo1nlmlhevfTGzrmb2ipn9r5n9Ih7/v5hZJzP7c2g708zmmlmRmT1jDaLaF3Q/4Ji+mJn93syWmNmfzCwtXvvSaN+dZrYgtB03v2tN+9Lk/Z/eXn2J+9A3s0HA10IP7wBeBfKBK8xsaIR9McfMTgVGuPtFwDzgMeK0L8BXgCXufiEwAvg1cdQXM+sMLAMuDe26GSh193wgJ7Q/2n2BCtOXC4E0dz8PyAImEr99afr+hzjJgKZ9CfP+70879SXuQx+YAdwb2p4AzHf3euAvwPgI+2LRJUCOmS0GLgJOIX77Ug10CV0dZgIXEEd9cfeD7n42UBraNQGYH9peRKP6o9gXqDB92UHDewbgUOif8doXOPr9D3GSAWH60vT9v5F26ktch76ZfRkoAlaHdvUEKkLblUCPCPtiUS5Q5u7/QMNZfgzx25c/AlcAHwFraKg1XvsC0f9exXyf3H2tuy81s2uBdOB14rQvYd7/EKd94dj3/1jaqS9xHfrAVTScIWcBo2lYqyI79Fw2DetWlIfZF4sqgZLQ9gZgE/Hbl3uB/+Puw2n4pRxK/PYFwtca7b6YY2ZXA1OBSe5eR/z25aj3v5l9l/jtS9P3fz/aqS9xHfru/mV3HwtMpmF87JfARDNLAcYBbwILw+yLRcuAc0PbQ2j4BYjXvnQHqkLb1cB7xG9fIFRraHsCjeqPYl9MMbM+wPeAK919b2h3XPal6fvf3Z8k/O9VPPyuNX3/b6Cd+hLXoR/G48DngWLgVXdfF2FfzHH394ByM/uAhsC/hTjtCw0n32+b2XtAZ+Ba4rcvAM8C/cysGNhNwxsv2n2x5lagL/C6mb1jZl8nfvsSTlxmQNP3v7svpZ36om/kiogkkUS70hcRkWYo9EVEkohCX0QkiSj0RUSSiEJfRCSJKPRFRJLI/weJ3wZo6nCTdwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "metric[metric['classification/accuracy'].notnull()]['classification/accuracy'].plot()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e01e1f38",
   "metadata": {},
   "source": [
    "&emsp;&emsp;This accuracy is honestly quite mediocre, most likely because we provided too little training data. Let's first look at how to use the model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9f4c21f",
   "metadata": {},
   "source": [
    "**Step 3: Usage**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "031123aa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "对给定文本进行分类，类别包括：科技、金融、娱乐、世界、汽车、文化、军事、旅游、游戏、教育、农业、房产、社会、股票。\n",
      "\n",
      "给定文本：\n",
      "出栏一头猪亏损300元，究竟谁能笑到最后！\n",
      "类别：\n",
      "\n"
     ]
    }
   ],
   "source": [
    "prompt = get_prompt(lines[2][\"sentence\"])\n",
    "print(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "9a1b49be",
   "metadata": {},
   "outputs": [],
   "source": [
    "def complete(prompt, model, max_tokens=2):\n",
    "    response = openai.Completion.create(\n",
    "        prompt=prompt,\n",
    "        temperature=0,\n",
    "        max_tokens=max_tokens,\n",
    "        top_p=1,\n",
    "        frequency_penalty=0,\n",
    "        presence_penalty=0,\n",
    "        model=model\n",
    "    )\n",
    "    ans = response[\"choices\"][0][\"text\"].strip(\" \\n\")\n",
    "    return ans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "58ad560c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'社会'"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The original model, before fine-tuning\n",
    "complete(prompt, \"text-davinci-003\", 5)"
   ]
  },
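  {
   "cell_type": "markdown",
   "id": "2d29f339",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Note that our prompt has to change as well. The expected output is now an English word; recall the `completion` field in the training data. For illustration, we deliberately used Chinese labels in the earlier prompt and English labels for fine-tuning, so that an English label in the output directly shows that the fine-tuning took effect. In real use, be sure to keep the labels consistent.\n",
    "\n",
    "&emsp;&emsp;Now we switch the model to the one we just fine-tuned (the `fine_tuned_model` field in the result returned above)."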
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "bfc23ed9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'agriculture'"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# After fine-tuning\n",
    "prompt = lines[2][\"sentence\"] + \" ->\"\n",
    "complete(prompt, \"davinci:ft-personal-2023-04-04-14-51-29\", 1)"
   ]
  },
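  {
   "cell_type": "markdown",
   "id": "759f8a17",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Huh, it has become agriculture. That matches the earlier ChatGPT (`ChatCompletion`) output, but it is still not the `金融` (finance) label that the dataset assigns to this sentence. The sentence does look more like an agriculture topic, so filing it under agriculture seems reasonable too. Besides, the data we fine-tuned on did not include this sample, so this is not a big problem.\n",
    "\n",
    "&emsp;&emsp;If we insist that it be classified as finance, we can add this sample to the data sent to the fine-tuning endpoint; after fine-tuning again, the model should return the label given in the dataset."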
   ]
  },
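  {
   "cell_type": "markdown",
   "id": "d91b3f01",
   "metadata": {},
   "source": [
    "&emsp;&emsp;A single disagreement says little on its own, so it can help to score the model on a batch of held-out sentences. The cell below is a rough sketch under stated assumptions: `evaluate` is a hypothetical helper, the two example sentences and their labels are made up for illustration, and `fake_predict` stands in for a real call such as `complete(prompt, fine_tuned_model, 1)` so that the cell runs offline. The only detail taken from above is the prompt format, which appends `\" ->\"` to the sentence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d91b3f02",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(examples, predict):\n",
    "    \"\"\"Score `predict` on (sentence, gold_label) pairs.\n",
    "\n",
    "    `predict` is any callable mapping a prompt string to a label string.\n",
    "    Returns (accuracy, list of (sentence, gold, predicted) mistakes).\n",
    "    \"\"\"\n",
    "    correct = 0\n",
    "    mistakes = []\n",
    "    for sentence, gold in examples:\n",
    "        pred = predict(sentence + ' ->')  # same prompt format as used for fine-tuning\n",
    "        if pred == gold:\n",
    "            correct += 1\n",
    "        else:\n",
    "            mistakes.append((sentence, gold, pred))\n",
    "    accuracy = correct / len(examples) if examples else 0.0\n",
    "    return accuracy, mistakes\n",
    "\n",
    "\n",
    "def fake_predict(prompt):\n",
    "    # Offline stand-in for the fine-tuned model (an assumption, not a real model call)\n",
    "    return 'agriculture' if 'pig' in prompt.lower() else 'finance'\n",
    "\n",
    "\n",
    "# Made-up examples, for illustration only\n",
    "toy_examples = [\n",
    "    ('Pig farmers lose 300 yuan per head', 'finance'),\n",
    "    ('Banks cut deposit rates again', 'finance'),\n",
    "]\n",
    "accuracy, mistakes = evaluate(toy_examples, fake_predict)\n",
    "print(accuracy)\n",
    "print(mistakes)"
   ]
  },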
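  {
   "cell_type": "markdown",
   "id": "d00ffbc7",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Above we covered fine-tuning for topic classification. Fine-tuning for entity extraction works in a similar way; the recommended input format is as follows:"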
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "cc8b15c2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'prompt': '<any text, for example news article>\\n\\n###\\n\\n',\n",
       " 'completion': ' <list of entities, separated by a newline> END'}"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "{\"prompt\":\"<any text, for example news article>\\n\\n###\\n\\n\", \"completion\":\" <list of entities, separated by a newline> END\"}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c01f7593",
   "metadata": {},
   "source": [
    "&emsp;&emsp;For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "a1384214",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'prompt': 'Portugal will be removed from the UK\\'s green travel list from Tuesday, amid rising coronavirus cases and concern over a \"Nepal mutation of the so-called Indian variant\". It will join the amber list, meaning holidaymakers should not visit and returnees must isolate for 10 days...\\n\\n###\\n\\n',\n",
       " 'completion': ' Portugal\\nUK\\nNepal mutation\\nIndian variant END'}"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "{\"prompt\":\"Portugal will be removed from the UK's green travel list from Tuesday, amid rising coronavirus cases and concern over a \\\"Nepal mutation of the so-called Indian variant\\\". It will join the amber list, meaning holidaymakers should not visit and returnees must isolate for 10 days...\\n\\n###\\n\\n\", \"completion\":\" Portugal\\nUK\\nNepal mutation\\nIndian variant END\"}"
   ]
  },
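  {
   "cell_type": "markdown",
   "id": "050b7bce",
   "metadata": {},
   "source": [
    "&emsp;&emsp;This should be easy to follow, so try a few experiments of your own. In particular, try fine-tuning on entities from a specialized domain and compare the results before and after fine-tuning.\n",
    "\n",
    "&emsp;&emsp;If you are interested in this topic, see [References 11 and 12] for further reading."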
   ]
  },
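  {
   "cell_type": "markdown",
   "id": "e52c8a01",
   "metadata": {},
   "source": [
    "&emsp;&emsp;The entity-extraction format above can be produced and read back with two small helpers. This is a minimal sketch: `make_record` and `parse_completion` are hypothetical names of our own. The separator `\\n\\n###\\n\\n`, the leading space in the completion, and the trailing ` END` marker are taken from the format shown above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e52c8a02",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEPARATOR = '\\n\\n###\\n\\n'  # marks the end of the prompt\n",
    "STOP = ' END'              # marks the end of the completion\n",
    "\n",
    "\n",
    "def make_record(text, entities):\n",
    "    \"\"\"Build one training record in the prompt/completion format shown above.\"\"\"\n",
    "    return {\n",
    "        'prompt': text + SEPARATOR,\n",
    "        'completion': ' ' + '\\n'.join(entities) + STOP,\n",
    "    }\n",
    "\n",
    "\n",
    "def parse_completion(completion):\n",
    "    \"\"\"Turn a completion string back into a list of entity strings.\"\"\"\n",
    "    body = completion.strip()\n",
    "    if body == 'END':  # no entities at all\n",
    "        return []\n",
    "    if body.endswith(STOP):  # drop the trailing end marker\n",
    "        body = body[:-len(STOP)]\n",
    "    return [line.strip() for line in body.split('\\n') if line.strip()]\n",
    "\n",
    "\n",
    "record = make_record(\n",
    "    'Portugal will be removed from the UK\\'s green travel list from Tuesday...',\n",
    "    ['Portugal', 'UK'],\n",
    ")\n",
    "print(record)\n",
    "print(parse_completion(record['completion']))"
   ]
  },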
  {
   "cell_type": "markdown",
   "id": "28bee93e",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Intelligent Dialogue"
   ]
  },
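  {
   "cell_type": "markdown",
   "id": "2e1bf684",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Intelligent dialogue systems also go by names such as intelligent customer service, dialogue bots, and chatbots. In short, this is a technology for interacting with users through a chat window. A traditional dialogue bot (the term we use from here on) generally consists of three major modules:\n",
    "\n",
    "- NLU: responsible for understanding the user's input. As mentioned at the start of this chapter, this mainly comes down to two techniques, intent classification and entity recognition. In practice there may also be components for entity relation extraction, sentiment recognition, and so on.\n",
    "- DM: Dialogue Management. Given the NLU result, it decides how the bot should reply, i.e. it controls the direction of the conversation.\n",
    "- NLG: Natural Language Generation. It produces the final output that is returned to the user."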
   ]
  },
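  {
   "cell_type": "markdown",
   "id": "66f06889",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Dialogue bots generally fall into three types, and the technical approach emphasizes different things for each. The common types are:\n",
    "\n",
    "- Task-oriented bots\n",
    "- Question-answering bots\n",
    "- Chitchat bots\n",
    "\n",
    "&emsp;&emsp;Task-oriented bots are mainly used to complete a specific task, such as booking a flight or ordering a meal. The key for this kind of bot is to obtain all the information needed to complete the task (the technical term is slots). The whole conversation can be viewed as a slot-filling process: the bot keeps talking with the user until it has collected the required slot values. Take a restaurant reservation as an example: the number of diners, the dining time, and a contact phone number are the basic information, and the bot has to find a way to obtain them. NLU carries most of the weight here, and DM typically takes one of two approaches: model-driven control or flowchart-driven control. The former learns the transitions automatically with a model, while the latter routes the conversation according to the intent type.\n",
    "\n",
    "&emsp;&emsp;Question-answering bots are mainly used to answer user questions, somewhat like the QA introduced in the previous chapter; the customer service bots we commonly encounter are usually of this type. What matters most for them is matching the user's input against known Questions, and DM plays a weaker role.\n",
    "\n",
    "&emsp;&emsp;Chitchat bots simply make small talk with users and serve little practical purpose."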
   ]
  },
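  {
   "cell_type": "markdown",
   "id": "c1a00001",
   "metadata": {},
   "source": [
    "&emsp;&emsp;The slot-filling process described above can be sketched in a few lines. This is a minimal illustration rather than a production design: the slot names and the helpers `update_slots` and `next_missing_slot` are hypothetical, and the extracted values are assumed to come from some NLU component. The bot would ask about whichever slot `next_missing_slot` returns until it returns `None`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a00002",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, Optional\n",
    "\n",
    "# Hypothetical slot names for a table-booking task (illustrative only).\n",
    "REQUIRED_SLOTS = [\"cellphone\", \"people_number\", \"meal_time\"]\n",
    "\n",
    "\n",
    "def update_slots(slots: Dict[str, Optional[object]], extracted: Dict[str, object]) -> Dict[str, Optional[object]]:\n",
    "    \"\"\"Merge newly extracted values into a copy of the slot state, ignoring empty values.\"\"\"\n",
    "    merged = dict(slots)\n",
    "    for name, value in extracted.items():\n",
    "        if name in REQUIRED_SLOTS and value is not None:\n",
    "            merged[name] = value\n",
    "    return merged\n",
    "\n",
    "\n",
    "def next_missing_slot(slots: Dict[str, Optional[object]]) -> Optional[str]:\n",
    "    \"\"\"Return the first required slot that is still empty, or None when all are filled.\"\"\"\n",
    "    for name in REQUIRED_SLOTS:\n",
    "        if slots.get(name) is None:\n",
    "            return name\n",
    "    return None\n",
    "\n",
    "\n",
    "# Start with empty slots, then pretend NLU extracted a phone number from the user.\n",
    "state = {name: None for name in REQUIRED_SLOTS}\n",
    "state = update_slots(state, {\"cellphone\": \"13788889999\"})\n",
    "print(next_missing_slot(state))  # prints: people_number"
   ]
  },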
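  {
   "cell_type": "markdown",
   "id": "005f9c55",
   "metadata": {},
   "source": [
    "&emsp;&emsp;This is only a rough classification. Dialogue bots in real scenarios usually combine several of these functions, so it is more useful to divide them by whether they initiate the conversation or receive it:\n",
    "\n",
    "- Bots that initiate the conversation\n",
    "- Bots that receive the conversation\n",
    "\n",
    "&emsp;&emsp;The former usually work through outbound calls; marketing, payment reminders and notifications are common scenarios. These bots generally do not chit-chat, because phone charges do not allow it. They mostly follow a flow with a specific task or goal and hang up once the flow is finished. Interaction with the user mostly takes the form of QA. Because the bot holds the initiative, the flow is generally fixed, and even the number of QA pairs and the number of answers may be limited.\n",
    "\n",
    "&emsp;&emsp;The latter usually live in a web page or client app, and in most cases the user comes to them. For example, the \"intelligent customer service\" entry on many company home pages is this kind of bot. They are mainly QA, supplemented by chit-chat. The task-oriented bots mentioned above are slightly more complex because they need to keep collecting slot information."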
   ]
  },
  {
   "cell_type": "markdown",
   "id": "180e8317",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Will intelligent dialogue bots change in the ChatGPT era? Let us explore this next."
   ]
  },
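  {
   "cell_type": "markdown",
   "id": "c960e0dc",
   "metadata": {},
   "source": [
    "&emsp;&emsp;First, it is certain that ChatGPT has greatly expanded the boundaries of dialogue bots. There were quite a few end-to-end approaches before it; interested readers can skim [Reference 13]. In hindsight those approaches were somewhat complex and cumbersome, and their results were not good, so apart from chit-chat very few dialogue bots actually used an end-to-end approach. ChatGPT's strong In-Context ability makes it simpler to use (we only need to pass in the conversation history by role) and also gives better results. Beyond chit-chat it can also be very good as a question-answering bot, and the interaction feels more human-like.\n",
    "\n",
    "&emsp;&emsp;Let us look at what it can do. We try to focus on what is possible today rather than speculate about the future.\n",
    "\n",
    "- As a question-answering product, for example knowledge Q&A, relationship advice or psychological counseling, it is a reasonable first stop for almost any question. Some simple examples: ask it about a programming concept (such as what closures are for), how to approach someone you like, or how to avoid anxiety. Most of its answers are impressive, at a level that earlier chit-chat-only bots could not reach.\n",
    "- As intelligent customer service, combined with an enterprise knowledge base (via the In-Context approach or the fine-tuning approach①), it can handle customer service work. Compared with earlier QA-style customer service, its answers are more personalized and the results are better (if you do not remember, revisit the earlier \"Document Question Answering\" content).\n",
    "- As an intelligent marketing bot. Intelligent customer service leans toward passively answering user questions, while a marketing bot is more proactive: based on stored user information it recommends relevant products and starts conversations with users according to preset goals (which can be understood as slots, information to be collected over the long term). It can also help maintain customer relationships.\n",
    "- As an NPC (Non-Player Character), a companion chat bot, or other leisure and entertainment products.\n",
    "- As a tutor for education and training, offering one-on-one teaching, especially suited to learning languages and programming."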
   ]
  },
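  {
   "cell_type": "markdown",
   "id": "6a740f00",
   "metadata": {},
   "source": [
    "&emsp;&emsp;These are things it can definitely do. Why? In the end it comes down to the knowledge learned by its large number of parameters and its ability to understand. The latter is probably the decisive factor (knowledge alone would amount to a search engine such as Google)."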
   ]
  },
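  {
   "cell_type": "markdown",
   "id": "425afe51",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Of course, not everything needs ChatGPT. We must avoid the \"holding a hammer and looking for nails everywhere\" mindset. As someone wise once said, a new technology tends to be overestimated in the short term and underestimated in the long term. ChatGPT is certainly epoch-making, but that does not mean everything should be run through ChatGPT. For example, for some classification and entity extraction tasks the existing methods already achieve very good results, and there is no need to replace them. Many practical tasks do not change much as technology evolves: a classification task remains a classification task when a new technique appears. New technology improves our efficiency, which means the same task becomes simpler and faster to do and harder tasks become feasible, but the task itself does not change. Be sure to keep this distinction clear."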
   ]
  },
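  {
   "cell_type": "markdown",
   "id": "e8c7ad85",
   "metadata": {},
   "source": [
    "&emsp;&emsp;If you are starting a new project, or you lack expertise in this area, the situation is different, and using an LLM API may well be the better strategy. Even so, before going live you should think through the details: what happens when the service is unavailable, roughly how much concurrency to expect, what latency is required, how large the user base is, and so on. The choice of technical solution is closely tied to the needs of your company or yourself. There is no absolutely best solution, only one that fits the present situation. At the same time, think a few steps ahead, but not too many (always remember: \"over-optimization is the original sin\"). For example, if you have fewer than 10,000 users, starting with a distributed design is overkill. This does not stop you from considering extensibility in code and architecture. Take the database: we may use SQLite, but the code should not be tightly coupled to it. Use an ORM that also supports other databases, even distributed ones. This is slightly more work to write (really only slightly), but the code is cleaner and decoupled from things that may change. If the scale grows later, the database can be swapped with little or no change to the code."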
   ]
  },
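  {
   "cell_type": "markdown",
   "id": "8df1c1ce",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Finally, we should also understand some limitations of ChatGPT. Besides its inherent limitations (see the later chapter devoted to its shortcomings), from an engineering perspective you should always pay attention to the following topics:\n",
    "\n",
    "- Response time and stability\n",
    "- Concurrency and horizontal scalability\n",
    "- Maintainability and iteration\n",
    "- Cost\n",
    "\n",
    "&emsp;&emsp;Choose it only when all of these meet your expectations. Always remember that people are the key; do not let any tool hold you hostage."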
   ]
  },
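  {
   "cell_type": "markdown",
   "id": "ec00a6a9",
   "metadata": {},
   "source": [
    "> Note ①: A plain-language explanation of the In-Context approach and the fine-tuning approach  \n",
    "&emsp;&emsp;The In-Context approach mainly relies on ChatGPT's ability to understand. We treat it as a super brain, give it the relevant context and let it answer based on that context, similar to the earlier \"Document Question Answering\". For questions it is unsure about, we can also design a fallback response.  \n",
    "&emsp;&emsp;The fine-tuning approach feeds custom data directly to ChatGPT's fine-tuning interface (not open at the time of writing, but feasible in principle) so that it learns the custom content. Afterwards we can simply ask it directly, just as it now answers correctly when asked \"What is the capital of China?\"."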
   ]
  },
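  {
   "cell_type": "markdown",
   "id": "c1a00003",
   "metadata": {},
   "source": [
    "&emsp;&emsp;The In-Context approach from the note above can be sketched as a prompt builder. This is only an illustration under stated assumptions: `build_messages`, the `FALLBACK` phrase and the sample knowledge-base sentences are hypothetical, and retrieving the relevant context is assumed to happen elsewhere. The returned list follows the role/content message format used for chat models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a00004",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, List\n",
    "\n",
    "# Hypothetical fallback phrase used when the context does not contain the answer.\n",
    "FALLBACK = \"Sorry, I could not find that in our knowledge base.\"\n",
    "\n",
    "\n",
    "def build_messages(question: str, contexts: List[str], fallback: str = FALLBACK) -> List[Dict[str, str]]:\n",
    "    \"\"\"Build chat messages that restrict the answer to the supplied context.\"\"\"\n",
    "    context_block = \"\\n\".join(f\"- {c}\" for c in contexts)\n",
    "    system = (\n",
    "        \"Answer the user's question using only the context below. \"\n",
    "        f\"If the context does not contain the answer, reply exactly: {fallback}\\n\\n\"\n",
    "        f\"Context:\\n{context_block}\"\n",
    "    )\n",
    "    return [\n",
    "        {\"role\": \"system\", \"content\": system},\n",
    "        {\"role\": \"user\", \"content\": question},\n",
    "    ]\n",
    "\n",
    "\n",
    "# Sample (made-up) knowledge-base sentences for demonstration.\n",
    "messages = build_messages(\n",
    "    \"What are your opening hours?\",\n",
    "    [\"The restaurant is open from 11:00 to 22:00.\", \"Reservations can be made by phone.\"],\n",
    ")\n",
    "print(messages[0][\"content\"])"
   ]
  },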
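  {
   "cell_type": "markdown",
   "id": "15547dd6",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Next, let us use ChatGPT to build a dialogue bot. At the design stage, we need to consider at least the following factors (in addition to those mentioned above):\n",
    "\n",
    "- Purpose of use\n",
    "- How to use it\n",
    "- Message querying and storage\n",
    "- Message parsing\n",
    "- Real-time intervention\n",
    "- Update strategy"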
   ]
  },
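  {
   "cell_type": "markdown",
   "id": "d8ca23c0",
   "metadata": {},
   "source": [
    "&emsp;&emsp;First, we need to be clear about the purpose, because, as noted above, different uses require different considerations. To keep things simple (but still realistic), we take a \"table-booking bot\" as the example: after a short opening, it collects three pieces of information from the user: contact phone number, number of diners and dining time."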
   ]
  },
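  {
   "cell_type": "markdown",
   "id": "81b63d1f",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Usage is also fairly simple: we mainly rely on ChatGPT's multi-turn dialogue ability, and the key point is controlling the context. Because the task is simple, we do not need to retrieve history before each turn. In each turn we tell the model what information has already been collected and ask it to keep collecting the rest until everything has been obtained. We can also limit the number of output tokens (the length of the output text).\n",
    "\n",
    "&emsp;&emsp;In practice, user messages (and the bot's replies) usually need to be stored so that historical messages can be retrieved for each turn. They may also have other uses later, such as building user profiles from the conversation logs or serving as training data. Storage can go directly into a database, or into an internal search engine such as ElasticSearch."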
   ]
  },
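  {
   "cell_type": "markdown",
   "id": "4f54d4c4",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Message parsing can be done in real time (not necessarily with ChatGPT) or offline. In this example we need real-time online parsing, and we can let ChatGPT do it as part of generating the reply."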
   ]
  },
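  {
   "cell_type": "markdown",
   "id": "43c46898",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Real-time intervention deserves attention, and a module should be designed for it. On the one hand, even with restrictions in place, certain phrasings may still elicit inappropriate replies. On the other hand, we cannot rule out malicious users attacking the bot, so it is best to design an intervention mechanism. Here we use a simple strategy: check whether the user's message is sensitive, and if so return a preset text directly without calling ChatGPT for a reply."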
   ]
  },
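  {
   "cell_type": "markdown",
   "id": "c1a00005",
   "metadata": {},
   "source": [
    "&emsp;&emsp;The intervention strategy above can be sketched as a small gate in front of the model call. Everything named here is an assumption for illustration: `guarded_reply`, the keyword-based `keyword_risk` check, the `PRESET_REPLY` text and the `echo_bot` placeholder are hypothetical, and a real system would plug in a proper moderation check and the actual model call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a00006",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Callable\n",
    "\n",
    "# Hypothetical preset reply and keyword list; a real system would call a moderation API.\n",
    "PRESET_REPLY = \"Sorry, I can't help with that.\"\n",
    "BLOCKED_KEYWORDS = (\"attack\", \"exploit\")\n",
    "\n",
    "\n",
    "def keyword_risk(text: str) -> bool:\n",
    "    \"\"\"Toy stand-in for a moderation check: flag text containing a blocked keyword.\"\"\"\n",
    "    lowered = text.lower()\n",
    "    return any(word in lowered for word in BLOCKED_KEYWORDS)\n",
    "\n",
    "\n",
    "def guarded_reply(user_inp: str, generate: Callable[[str], str], is_risky: Callable[[str], bool] = keyword_risk) -> str:\n",
    "    \"\"\"Return the preset reply for risky input; otherwise delegate to the generator.\"\"\"\n",
    "    if is_risky(user_inp):\n",
    "        return PRESET_REPLY\n",
    "    return generate(user_inp)\n",
    "\n",
    "\n",
    "def echo_bot(text: str) -> str:\n",
    "    \"\"\"Placeholder for the model call.\"\"\"\n",
    "    return f\"You said: {text}\"\n",
    "\n",
    "\n",
    "print(guarded_reply(\"Book a table for two\", echo_bot))\n",
    "print(guarded_reply(\"How do I exploit this bot?\", echo_bot))"
   ]
  },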
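  {
   "cell_type": "markdown",
   "id": "d8e415fd",
   "metadata": {},
   "source": [
    "&emsp;&emsp;The update strategy mainly concerns updating the enterprise knowledge base. Because we use the In-Context ability here, there is no need to adjust ChatGPT itself; we might need to adjust the Embedding interface (which OpenAI does not currently support). This example does not cover it."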
   ]
  },
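  {
   "cell_type": "markdown",
   "id": "47e6526a",
   "metadata": {},
   "source": [
    "&emsp;&emsp;To sum up, we first run a sensitivity check on the user input and start the conversation only if it passes. We also store the user messages and pass the user's message history to the API in each turn."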
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "543d111d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import openai\n",
    "\n",
    "openai.api_key = os.environ.get(\"OPENAI_API_KEY\")"
   ]
  },
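  {
   "cell_type": "markdown",
   "id": "4e6f54b7",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Let us look at the sensitivity check first. Many APIs are available for this: OpenAI provides one, and several major Chinese vendors offer similar APIs. The check itself is independent of the dialogue. We use the OpenAI API as the example."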
   ]
  },
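  {
   "cell_type": "code",
   "execution_count": 197,
   "id": "476770fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import requests\n",
    "\n",
    "def check_risk(inp: str) -> bool:\n",
    "    \"\"\"Return True if the OpenAI moderation endpoint flags the input.\"\"\"\n",
    "    safe_api = \"https://api.openai.com/v1/moderations\"\n",
    "    # Read the API key from the environment, as in the setup cell above.\n",
    "    api_key = os.environ.get(\"OPENAI_API_KEY\")\n",
    "    resp = requests.post(safe_api, json={\"input\": inp}, headers={\"Authorization\": f\"Bearer {api_key}\"})\n",
    "    data = resp.json()\n",
    "    return data[\"results\"][0][\"flagged\"]"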
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 198,
   "id": "d14fc920",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 198,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "check_risk(\"good\")"
   ]
  },
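  {
   "cell_type": "markdown",
   "id": "902e979e",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Next we consider how to construct the API input. There are two things to do:\n",
    "\n",
    "1. Query the conversation history to use as context. For simplicity we can consider only the previous turn.\n",
    "2. Count the input tokens. From the model's maximum token length and the desired maximum output length, work out the maximum context length, then process the history accordingly (for example by truncating it)."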
   ]
  },
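  {
   "cell_type": "markdown",
   "id": "c1a00007",
   "metadata": {},
   "source": [
    "&emsp;&emsp;The second step can be sketched as follows. This is a simplified illustration, not the method used later in this chapter: `truncate_history` is a hypothetical helper, it assumes the first message is the system prompt, and `rough_token_count` is a crude stand-in that counts one token per character. A real implementation would use the model's tokenizer and account for per-message overhead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a00008",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Callable, Dict, List\n",
    "\n",
    "\n",
    "def rough_token_count(text: str) -> int:\n",
    "    \"\"\"Crude stand-in for a tokenizer: counts one token per character.\"\"\"\n",
    "    return len(text)\n",
    "\n",
    "\n",
    "def truncate_history(\n",
    "    messages: List[Dict[str, str]],\n",
    "    max_model_tokens: int,\n",
    "    max_output_tokens: int,\n",
    "    count_tokens: Callable[[str], int] = rough_token_count,\n",
    ") -> List[Dict[str, str]]:\n",
    "    \"\"\"Keep the system message plus the most recent turns that fit the input budget.\"\"\"\n",
    "    budget = max_model_tokens - max_output_tokens  # tokens left for the input\n",
    "    system, turns = messages[:1], messages[1:]\n",
    "    used = sum(count_tokens(m[\"content\"]) for m in system)\n",
    "    kept: List[Dict[str, str]] = []\n",
    "    for m in reversed(turns):  # walk from the newest turn backwards\n",
    "        cost = count_tokens(m[\"content\"])\n",
    "        if used + cost > budget:\n",
    "            break\n",
    "        kept.append(m)\n",
    "        used += cost\n",
    "    return system + kept[::-1]\n",
    "\n",
    "\n",
    "demo = [\n",
    "    {\"role\": \"system\", \"content\": \"abc\"},\n",
    "    {\"role\": \"user\", \"content\": \"12345\"},\n",
    "    {\"role\": \"assistant\", \"content\": \"67\"},\n",
    "    {\"role\": \"user\", \"content\": \"890\"},\n",
    "]\n",
    "# The input budget is 12 - 4 = 8 tokens, so the oldest user turn is dropped.\n",
    "print([m[\"content\"] for m in truncate_history(demo, max_model_tokens=12, max_output_tokens=4)])"
   ]
  },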
  {
   "cell_type": "code",
   "execution_count": 243,
   "id": "ac4e0255",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass, asdict\n",
    "from typing import List, Dict\n",
    "from datetime import datetime\n",
    "import uuid\n",
    "import json\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "id": "36a1da7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class User:\n",
    "    \n",
    "    user_id: str\n",
    "    user_name: str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "id": "88f6822a",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class ChatSession:\n",
    "    \n",
    "    user_id: str\n",
    "    session_id: str\n",
    "    cellphone: str\n",
    "    people_number: int\n",
    "    meal_time: str\n",
    "    chat_at: datetime\n",
    "\n",
    "@dataclass\n",
    "class ChatRecord:\n",
    "    \n",
    "    user_id: str\n",
    "    session_id: str\n",
    "    user_input: str\n",
    "    bot_output: str\n",
    "    chat_at: datetime"
   ]
  },
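  {
   "cell_type": "markdown",
   "id": "598eb569",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Above we designed two simple data structures, one for the chat session and one for the chat records. The former stores basic information about the conversation, the latter the individual turns. The session_id distinguishes conversations: when the user clicks something like a \"Start chat\" button on the product page, a session_id is generated, and a new one is generated for the next conversation."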
   ]
  },
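  {
   "cell_type": "markdown",
   "id": "0f177084",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Next we handle the core dialogue logic. This part mainly relies on ChatGPT: we state the requirements clearly, feed it every turn of the conversation, and let it produce the response."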
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 279,
   "id": "53d741a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ask(msg):\n",
    "    response = openai.ChatCompletion.create(\n",
    "        model=\"gpt-3.5-turbo\", \n",
    "        temperature=0.2,\n",
    "        max_tokens=100,\n",
    "        top_p=1,\n",
    "        frequency_penalty=0,\n",
    "        presence_penalty=0,\n",
    "        messages=msg\n",
    "    )\n",
    "    ans = response.get(\"choices\")[0].get(\"message\").get(\"content\")\n",
    "    return ans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5b25cbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install sqlalchemy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 284,
   "id": "d7da2516",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sqlalchemy import insert\n",
    "\n",
    "class Chatbot:\n",
    "    \n",
    "    def __init__(self):\n",
    "        # System prompt: the assistant is a table-booking bot that collects the phone number,\n",
    "        # party size and dining time, and replies in the fixed two-line format shown in the prompt.\n",
    "        self.system_inp = \"\"\"现在你是一个订餐机器人（角色是assistant），你的目的是向用户获取手机号码、用餐人数量和用餐时间三个信息。你可以自由回复用户消息，但牢记你的目的。每一轮你需要输出给用户的回复，以及获取到的信息，信息应该以JSON方式存储，包括三个key：cellphone表示手机号码，people_number表示用餐人数，meal_time表示用餐时间。\n",
    "\n",
    "回复格式：\n",
    "给用户的回复：{回复给用户的话}\n",
    "获取到的信息：{\"cellphone\": null, \"people_number\": null, \"meal_time\": null}\n",
    "\"\"\"\n",
    "        self.max_round = 10\n",
    "        self.slot_labels = [\"meal_time\", \"people_number\", \"cellphone\"]\n",
    "        self.reg_msg = re.compile(r\"\\n+\")\n",
    "\n",
    "\n",
    "    def check_over(self, slot_dict: dict):\n",
    "        # The conversation is over once every slot has a value.\n",
    "        for label in self.slot_labels:\n",
    "            if slot_dict.get(label) is None:\n",
    "                return False\n",
    "        return True\n",
    "    \n",
    "    def send_msg(self, msg: str):\n",
    "        print(f\"机器人：{msg}\")\n",
    "    \n",
    "    def chat(self, user_id: str):\n",
    "        sess_id = uuid.uuid4().hex\n",
    "        chat_at = datetime.now()\n",
    "        msg = [\n",
    "            {\"role\": \"system\", \"content\": self.system_inp},\n",
    "        ]\n",
    "        n_round = 0\n",
    "        \n",
    "        history = []\n",
    "        # Initialize the slots so that check_over() and ChatSession work even if no JSON is parsed.\n",
    "        slot = {label: None for label in self.slot_labels}\n",
    "        while True:\n",
    "            if n_round > self.max_round:\n",
    "                bot_msg = \"非常感谢您对我们的支持，再见。\"\n",
    "                self.send_msg(bot_msg)\n",
    "                break\n",
    "            \n",
    "            try:\n",
    "                bot_inp = ask(msg)\n",
    "            except Exception as e:\n",
    "                bot_msg = \"机器人出错，稍后将由人工与您联系，谢谢。\"\n",
    "                self.send_msg(bot_msg)\n",
    "                break\n",
    "            \n",
    "            tmp = self.reg_msg.split(bot_inp)\n",
    "            # str.strip() removes a set of characters, not a prefix, so drop the labels with a regex.\n",
    "            bot_msg = re.sub(r\"^给用户的回复：\", \"\", tmp[0].strip())\n",
    "            self.send_msg(bot_msg)\n",
    "            if len(tmp) > 1:\n",
    "                slot_str = re.sub(r\"^获取到的信息：\", \"\", tmp[1].strip())\n",
    "                slot = json.loads(slot_str)\n",
    "                print(f\"\\tslot: {slot}\")\n",
    "            n_round += 1\n",
    "            \n",
    "            if self.check_over(slot):\n",
    "                break\n",
    "\n",
    "            user_inp = input()\n",
    "            \n",
    "            msg += [\n",
    "                {\"role\": \"assistant\", \"content\": bot_inp},\n",
    "                {\"role\": \"user\", \"content\": user_inp},\n",
    "            ]\n",
    "            \n",
    "            # Pass fields by name so the user text and the bot text land in the matching columns.\n",
    "            record = ChatRecord(user_id, sess_id, user_input=user_inp, bot_output=bot_inp, chat_at=datetime.now())\n",
    "            history.append(record)\n",
    "            \n",
    "            if check_risk(user_inp):\n",
    "                # Sensitive input: reply with a preset text (wording is illustrative) and stop.\n",
    "                self.send_msg(\"抱歉，这个问题我无法回答。\")\n",
    "                break\n",
    "        \n",
    "        chat_sess = ChatSession(user_id, sess_id, **slot, chat_at=chat_at)\n",
    "        self.store(history, chat_sess)\n",
    "    \n",
    "    \n",
    "    def store(self, history: List[ChatRecord], chat: ChatSession):\n",
    "        # Write the per-turn records first, then the session summary.\n",
    "        with SessionLocal.begin() as sess:\n",
    "            q = insert(\n",
    "                chat_record_table\n",
    "            ).values(\n",
    "                [asdict(v) for v in history]\n",
    "            )\n",
    "            sess.execute(q)\n",
    "        with SessionLocal.begin() as sess:\n",
    "            q = insert(\n",
    "                chat_session_table\n",
    "            ).values(\n",
    "                [asdict(chat)]\n",
    "            )\n",
    "            sess.execute(q)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79f6f800",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Next, we create the two tables (the User table is omitted here). **Note: the tables only need to be created once.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 285,
   "id": "0ffbcb2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sqlalchemy import Table, Column, Integer, String, DateTime, Text, MetaData, SmallInteger\n",
    "from sqlalchemy import create_engine\n",
    "from sqlalchemy.orm import sessionmaker\n",
    "import os\n",
    "\n",
    "\n",
    "db_file = \"chatbot.db\"\n",
    "\n",
    "if os.path.exists(db_file):\n",
    "    os.remove(db_file)\n",
    "\n",
    "engine = create_engine(f\"sqlite:///{db_file}\")\n",
    "SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)\n",
    "\n",
    "metadata_obj = MetaData()\n",
    "\n",
    "chat_record_table = Table(\n",
    "    \"chat_record_table\",\n",
    "    metadata_obj,\n",
    "    Column(\"id\", Integer, primary_key=True),\n",
    "    Column(\"user_id\", String(64), index=True),\n",
    "    Column(\"session_id\", String(64), index=True),\n",
    "    Column(\"user_input\", Text),\n",
    "    Column(\"bot_output\", Text),\n",
    "    Column(\"chat_at\", DateTime),\n",
    ")\n",
    "\n",
    "chat_session_table = Table(\n",
    "    \"chat_session_table\",\n",
    "    metadata_obj,\n",
    "    Column(\"id\", Integer, primary_key=True),\n",
    "    Column(\"user_id\", String(64), index=True),\n",
    "    Column(\"session_id\", String(64), index=True),\n",
    "    Column(\"cellphone\", String(16)),\n",
    "    Column(\"people_number\", SmallInteger),\n",
    "    Column(\"meal_time\", String(32)),\n",
    "    Column(\"chat_at\", DateTime),\n",
    ")\n",
    "\n",
    "metadata_obj.create_all(engine, checkfirst=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 286,
   "id": "6dfa400b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-rw-r--r--  1 Yam  staff  28672 Apr  8 00:00 ./chatbot.db\n"
     ]
    }
   ],
   "source": [
    "!ls -la ./chatbot.db"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "089530dd",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Now let us give it a quick try:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "780511c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install pnlp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 287,
   "id": "10fd3019",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "机器人：请问您的手机号码是多少呢？\n",
      "\tslot: {'cellphone': None, 'people_number': None, 'meal_time': None}\n",
      "我的手机是13788889999\n",
      "机器人：好的，您的手机号码是13788889999，请问用餐人数是几位呢？\n",
      "\tslot: {'cellphone': '13788889999', 'people_number': None, 'meal_time': None}\n",
      "我们一共五个人\n",
      "机器人：好的，您们一共五个人，最后，请问您们的用餐时间是什么时候呢？\n",
      "\tslot: {'cellphone': '13788889999', 'people_number': 5, 'meal_time': None}\n",
      "稍等我问一下啊\n",
      "机器人：好的，没问题，我等您的消息。\n",
      "好了，明天下午7点，谢谢\n",
      "机器人：好的，您们的用餐时间是明天下午7点，我们已经为您记录好了，请问还有其他需要帮助的吗？\n",
      "\tslot: {'cellphone': '13788889999', 'people_number': 5, 'meal_time': '明天下午7点'}\n"
     ]
    }
   ],
   "source": [
    "import pnlp\n",
    "nick = \"长琴\"\n",
    "user = User(pnlp.generate_uuid(nick), nick)\n",
    "chatbot = Chatbot()\n",
    "chatbot.chat(user.user_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "241fdc85",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Let us query the tables to see the records just written:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 288,
   "id": "e8cf915a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sqlite3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 289,
   "id": "05372159",
   "metadata": {},
   "outputs": [],
   "source": [
    "def query_table(table: str):\n",
    "    con = sqlite3.connect(\"chatbot.db\")\n",
    "    cur = con.cursor()\n",
    "    q = cur.execute(f\"SELECT * FROM {table}\")\n",
    "    return q.fetchall()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 290,
   "id": "4d8c2755",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(1,\n",
       "  'dc3be3b3516555d3b0b6a77a1d9c7e82',\n",
       "  '05a88a8e3db8490eacf14b8bb9800fcc',\n",
       "  '13788889999',\n",
       "  5,\n",
       "  '明天下午7点',\n",
       "  '2023-04-08 00:00:34.618232')]"
      ]
     },
     "execution_count": 290,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "query_table(\"chat_session_table\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 291,
   "id": "c352ab58",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(1,\n",
       "  'dc3be3b3516555d3b0b6a77a1d9c7e82',\n",
       "  '05a88a8e3db8490eacf14b8bb9800fcc',\n",
       "  '给用户的回复：请问您的手机号码是多少呢？\\n获取到的信息：{\"cellphone\": null, \"people_number\": null, \"meal_time\": null}',\n",
       "  '我的手机是13788889999',\n",
       "  '2023-04-08 00:00:47.498172'),\n",
       " (2,\n",
       "  'dc3be3b3516555d3b0b6a77a1d9c7e82',\n",
       "  '05a88a8e3db8490eacf14b8bb9800fcc',\n",
       "  '给用户的回复：好的，您的手机号码是13788889999，请问用餐人数是几位呢？\\n获取到的信息：{\"cellphone\": \"13788889999\", \"people_number\": null, \"meal_time\": null}',\n",
       "  '我们一共五个人',\n",
       "  '2023-04-08 00:01:18.694161'),\n",
       " (3,\n",
       "  'dc3be3b3516555d3b0b6a77a1d9c7e82',\n",
       "  '05a88a8e3db8490eacf14b8bb9800fcc',\n",
       "  '给用户的回复：好的，您们一共五个人，最后，请问您们的用餐时间是什么时候呢？\\n获取到的信息：{\"cellphone\": \"13788889999\", \"people_number\": 5, \"meal_time\": null}',\n",
       "  '稍等我问一下啊',\n",
       "  '2023-04-08 00:01:40.296970'),\n",
       " (4,\n",
       "  'dc3be3b3516555d3b0b6a77a1d9c7e82',\n",
       "  '05a88a8e3db8490eacf14b8bb9800fcc',\n",
       "  '好的，没问题，我等您的消息。',\n",
       "  '好了，明天下午7点，谢谢',\n",
       "  '2023-04-08 00:02:15.839735')]"
      ]
     },
     "execution_count": 291,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "query_table(\"chat_record_table\")"
   ]
  },
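  {
   "cell_type": "markdown",
   "id": "8600d248",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Above we built a very simple task-oriented bot. Simple as it is, it is easy to see that the NLU, DM and NLG modules no longer need to be built separately, because the model covers all three. The only shortcoming may be that the API responds a little slowly; the conversation itself shows no major problems."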
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af7c5ca2",
   "metadata": {},
   "source": [
    "&emsp;&emsp;A few points are worth emphasizing again to help you build better applications."
   ]
  },
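  {
   "cell_type": "markdown",
   "id": "6b505af5",
   "metadata": {},
   "source": [
    "&emsp;&emsp;First, when many turns must be supported (for example in training or interview scenarios), each turn should be indexed in real time. For each new turn, first retrieve the top N most relevant turns from the entire history as context (as we did in Document Question Answering), then let ChatGPT reply to the user based on that context. In theory this supports an unlimited number of turns. Retrieval is essentially a process of recollection, and there is a lot of room for optimization and imagination here."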
   ]
  },
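  {
   "cell_type": "markdown",
   "id": "c1a00009",
   "metadata": {},
   "source": [
    "&emsp;&emsp;The recall step described above can be sketched as follows. This is a toy illustration: `recall_top_n` is a hypothetical helper, and the bag-of-words cosine similarity used here is only a stand-in for the embedding-based retrieval one would normally use. The sample history sentences are made up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a0000a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from collections import Counter\n",
    "from typing import List\n",
    "\n",
    "\n",
    "def bow(text: str) -> Counter:\n",
    "    \"\"\"Bag-of-words vector over lowercase whitespace tokens (stand-in for an embedding).\"\"\"\n",
    "    return Counter(text.lower().split())\n",
    "\n",
    "\n",
    "def cosine(a: Counter, b: Counter) -> float:\n",
    "    \"\"\"Cosine similarity between two sparse count vectors.\"\"\"\n",
    "    dot = sum(a[k] * b[k] for k in a if k in b)\n",
    "    norm_a = math.sqrt(sum(v * v for v in a.values()))\n",
    "    norm_b = math.sqrt(sum(v * v for v in b.values()))\n",
    "    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0\n",
    "\n",
    "\n",
    "def recall_top_n(query: str, history: List[str], n: int = 2) -> List[str]:\n",
    "    \"\"\"Return the n history turns most similar to the query.\"\"\"\n",
    "    q = bow(query)\n",
    "    ranked = sorted(history, key=lambda turn: cosine(q, bow(turn)), reverse=True)\n",
    "    return ranked[:n]\n",
    "\n",
    "\n",
    "history = [\n",
    "    \"we booked a table for five people\",\n",
    "    \"the weather is nice today\",\n",
    "    \"my phone number is 13788889999\",\n",
    "]\n",
    "print(recall_top_n(\"what phone number did I give\", history, n=1))"
   ]
  },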
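  {
   "cell_type": "markdown",
   "id": "c3a44c22",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Second, the messages parameter passed to ChatGPT has a length limit, so when the context contains turns with very long replies, only a few turns may fit (a single turn may even use up the whole budget). In ChatGPT's own words: when the history is very long, we may indeed be able to use only a small part of it to generate an answer. To handle this, techniques are usually applied to select the most relevant history for generating the answer. For example, keyword extraction can identify the most relevant information in the history, which is then used together with the current input."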
   ]
  },
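  {
   "cell_type": "markdown",
   "id": "a61763bd",
   "metadata": {},
   "source": [
    "&emsp;&emsp;In addition, summarization techniques may be used to compress and condense the history so that only the most important information is used when generating the answer. Memory mechanisms, such as attention, can also be used to select the most relevant information from the history. Although these techniques help the model select the most relevant information when the history is long, in some cases the history may still be too complex for the model to understand and process correctly. In that case, other techniques may be needed to limit the length of the history or to provide additional supporting information so that the model can better understand and answer the user's question."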
   ]
  },
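  {
   "cell_type": "markdown",
   "id": "da5faae1",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Also according to ChatGPT, when generating a reply it uses some techniques to limit the output length, such as truncating the output or using strategies that produce more concise answers. Users can also apply specific input constraints or rules to help shorten the answer. In short, strike a balance between output length and answer quality as far as possible."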
   ]
  },
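  {
   "cell_type": "markdown",
   "id": "829b9f1f",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Third, give full consideration to security, and design the architecture sensibly for your actual situation (but do not over-engineer)."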
   ]
  },
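  {
   "cell_type": "markdown",
   "id": "e97a47f6",
   "metadata": {},
   "source": [
    "&emsp;&emsp;Finally, it is worth noting that the above uses only a tiny part of what ChatGPT can do. You can combine it with your own business, or let your imagination run free, to build more useful and interesting products and applications."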
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "189px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
